diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 0000000..efb5abf --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,12 @@ +# These are supported funding model platforms + +github: [chewxy, owulveryck, dcu] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] +patreon: # Replace with a single Patreon username +open_collective: # Replace with a single Open Collective username +ko_fi: # Replace with a single Ko-fi username +tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel +community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry +liberapay: # Replace with a single Liberapay username +issuehunt: # Replace with a single IssueHunt username +otechie: # Replace with a single Otechie username +custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] diff --git a/.github/workflows/.go.yml b/.github/workflows/.go.yml new file mode 100644 index 0000000..9710e2f --- /dev/null +++ b/.github/workflows/.go.yml @@ -0,0 +1,109 @@ +on: + push: + branches: [ master ] + pull_request: +name: test and build +env: + GOPROXY: "https://proxy.golang.org" + CI_NO_PYTHON: "true" +jobs: + test: + strategy: + matrix: + go: [1.18.x, 1.17.x, 1.16.x, 1.15.x] + os: [ubuntu-latest, macos-latest, windows-latest] + tags: [avx, sse] + allowfail: [false] + include: + - go: tip + os: ubuntu-latest + allowfail: true + runs-on: ${{ matrix.os }} + continue-on-error: ${{ matrix.allowfail }} + timeout-minutes: 5 + steps: + - name: Install Go ${{ matrix.go }} on ${{ matrix.os }} + if: matrix.go != 'tip' + uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go }} +# tempoary hack: +# https://github.com/actions/setup-go/issues/21#issuecomment-565704236 + - name: Install Go ${{ matrix.go }} on ${{ matrix.os }} + if: matrix.go == 'tip' + run: | + git clone --depth=1 https://go.googlesource.com/go $HOME/gotip + cd $HOME/gotip/src + ./make.bash + echo "GOROOT=$HOME/gotip" >> $GITHUB_ENV + echo "$HOME/gotip/bin" >> $GITHUB_PATH + - name: Checkout code + uses: actions/checkout@v2 + - name: Run tests + run: | + go test ./... -v -race + go test ./... -race -tags=${{ matrix.tags }} + + coverage: + env: + CI_NO_PYTHON: "false" + PYTHON_COMMAND: python + strategy: + matrix: + tags: [avx, sse] + runs-on: ubuntu-latest + steps: + - name: Install Go + uses: actions/setup-go@v2 + with: + go-version: 1.14.x + - name: Install Python + uses: actions/setup-python@v2 + with: + python-version: '3.x' + architecture: 'x64' + - name: Install Pip + uses: BSFishy/pip-action@v1 + with: + packages: numpy + - name: Checkout code + uses: actions/checkout@v2 + - name: Calc coverage + run: | + export PATH=$PATH:$(go env GOPATH)/bin + go test ./... -v -covermode=atomic -coverprofile=coverage.out + - name: Convert coverage to lcov + uses: jandelgado/gcov2lcov-action@v1.0.0 + with: + infile: coverage.out + outfile: coverage.lcov + - name: Coveralls + uses: coverallsapp/github-action@v1.0.1 + with: + github-token: ${{ secrets.github_token }} + path-to-lcov: coverage.lcov + + build: + strategy: + matrix: + go: [1.13, 1.14] + goos: [linux, darwin] + goarch: [amd64, arm] + exclude: + # windows/386 and darwin/386 seems useless + - goarch: "arm" + goos: darwin + runs-on: ubuntu-latest + needs: [test] + steps: + - name: Install Go ${{ matrix.go }} + uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go }} + - name: Checkout code + uses: actions/checkout@v2 + - name: build + run: go build . 
+ env: + GOOS: ${{ matrix.goos }} + GOARCH: ${{ matrix.goarch }} diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 1f1604d..0000000 --- a/.travis.yml +++ /dev/null @@ -1,31 +0,0 @@ -sudo: required -language: go - -branches: - only: - - master -go: - - 1.6.x - - 1.7.x - - 1.8.x - - 1.9.x - - tip - -env: - global: - - GOARCH=amd64 - - BLAS_LIB=OpenBLAS - - TRAVISTEST=true - - CUDA=8.0.61-1 - -before_install: - - go get github.com/mattn/goveralls - -go_import_path: gorgonia.org/tensor - -script: - - $HOME/gopath/bin/goveralls -service=travis-ci -package=gorgonia.org/tensor -covermode=atomic - -matrix: - allow_failures: - - go: tip \ No newline at end of file diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 9e473c5..d94f24c 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -3,6 +3,8 @@ * Xuanyi Chew (@chewxy) - initial package * Naseer Dari (@ndari) - errors and error handling * Joe Kabaka (@kabaka0) - masked array functionality +* Stuart Carnie (@stuartcarnie) - performance optimization for iterators +* Jorge Landivar (@docmerlin) - performance optimization for `*Dense` # Contributors @@ -12,8 +14,8 @@ * David Soller | @3ygun * Davor Kapsa | @dvrkps * James Michael DuPont | @h4ck3rm1k3 -* Jorge Landivar | @docmerlin * Yuanlin Lian | @alienchow +* Andrew SnodGrass | @pointlander diff --git a/Gopkg.lock b/Gopkg.lock deleted file mode 100644 index cb5367f..0000000 --- a/Gopkg.lock +++ /dev/null @@ -1,75 +0,0 @@ -# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. - - -[[projects]] - name = "github.com/chewxy/hm" - packages = ["."] - revision = "61efb3290a086d1335e8954b3734c102126818ba" - version = "v1.0.0" - -[[projects]] - name = "github.com/chewxy/math32" - packages = ["."] - revision = "d1e7b22839c693f54edf7811dd9487623abf2cd2" - version = "v1.0.0" - -[[projects]] - branch = "master" - name = "gorgonia.org/vecf32" - packages = ["."] - revision = "1f59516136c1a7f1c19871d3dc5f0d9928ffbd7c" - -[[projects]] - branch = "master" - name = "gorgonia.org/vecf64" - packages = ["."] - revision = "a97a4d31b6c9343b1860ef8ce583069671265b81" - -[[projects]] - name = "github.com/davecgh/go-spew" - packages = ["spew"] - revision = "346938d642f2ec3594ed81d874461961cd0faa76" - version = "v1.1.0" - -[[projects]] - name = "github.com/pkg/errors" - packages = ["."] - revision = "645ef00459ed84a119197bfb8d8205042c6df63d" - version = "v0.8.0" - -[[projects]] - name = "github.com/pmezard/go-difflib" - packages = ["difflib"] - revision = "792786c7400a136282c1664665ae0a8db921c6c2" - version = "v1.0.0" - -[[projects]] - name = "github.com/stretchr/testify" - packages = ["assert"] - revision = "69483b4bd14f5845b5a1e55bca19e954e827f1d0" - version = "v1.1.4" - -[[projects]] - branch = "master" - name = "github.com/xtgo/set" - packages = ["."] - revision = "4431f6b51265b1e0b76af4dafc09d6f12c2bdcd0" - -[[projects]] - branch = "master" - name = "golang.org/x/tools" - packages = ["go/ast/astutil","go/buildutil","go/loader"] - revision = "e531a2a1c15f94033f6fa87666caeb19a688175f" - -[[projects]] - branch = "master" - name = "gonum.org/v1/gonum" - packages = ["blas","blas/blas64","blas/gonum","floats","internal/asm/c128","internal/asm/f32","internal/asm/f64","internal/math32","lapack","lapack/gonum","lapack/lapack64","mat"] - revision = "f818f8f7a9e59de54e475b747a3dc9c86ed141f1" - -[solve-meta] - analyzer-name = "dep" - analyzer-version = 1 - inputs-digest = "eee289039fc6a17513fe715d3457e0997e6bffa8a27d57987d15cb1a17407705" - solver-name = "gps-cdcl" - 
solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml deleted file mode 100644 index 1a52f27..0000000 --- a/Gopkg.toml +++ /dev/null @@ -1,54 +0,0 @@ - -# Gopkg.toml example -# -# Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md -# for detailed Gopkg.toml documentation. -# -# required = ["github.com/user/thing/cmd/thing"] -# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"] -# -# [[constraint]] -# name = "github.com/user/project" -# version = "1.0.0" -# -# [[constraint]] -# name = "github.com/user/project2" -# branch = "dev" -# source = "github.com/myfork/project2" -# -# [[override]] -# name = "github.com/x/y" -# version = "2.4.0" - - -[[constraint]] - name = "github.com/chewxy/hm" - version = "~1.0.0" - -[[constraint]] - name = "github.com/chewxy/math32" - version = "~1.0.0" - -[[constraint]] - branch = "master" - name = "gorgonia.org/vecf32" - -[[constraint]] - branch = "master" - name = "gorgonia.org/vecf64" - -[[constraint]] - name = "github.com/pkg/errors" - version = "~0.8.0" - -[[constraint]] - name = "github.com/stretchr/testify" - version = "1.1.4" - -[[constraint]] - branch = "master" - name = "golang.org/x/tools" - -[[constraint]] - branch = "master" - name = "gonum.org/v1/gonum" diff --git a/LICENCE b/LICENCE index 7bec963..c7a1c7b 100644 --- a/LICENCE +++ b/LICENCE @@ -1,189 +1,202 @@ -The Gorgonia Licence - -Copyright (c) 2016 Xuanyi Chew - -Licensed under the Gorgonia License, Version 1.0 (the "License"); -you may not use this file except in compliance with the License. - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). 
- - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Significant Contribution" shall mean any Contribution that indicates a deep - understanding of the Work and/or its Derivatives thereof. - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. You are not permitted - to directly commercially profit from this Work unless You are also a - Significant Contributor, which is listed under the Contributors list. - -3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. -END OF TERMS AND CONDITIONS + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2019 Gorgonia Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/README.md b/README.md index 69d8b06..98108c8 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,20 @@ -# Package `tensor` [![GoDoc](https://godoc.org/github.com/chewxy/gorgonia/tensor?status.svg)](https://godoc.org/github.com/chewxy/gorgonia/tensor) [![Build Status](https://travis-ci.org/gorgonia/tensor.svg?branch=master)](https://travis-ci.org/gorgonia/tensor) [![Coverage Status](https://coveralls.io/repos/github/gorgonia/tensor/badge.svg?branch=master)](https://coveralls.io/github/gorgonia/tensor?branch=master) # +# Package `tensor` [![GoDoc](https://godoc.org/gorgonia.org/tensor?status.svg)](https://godoc.org/gorgonia.org/tensor) [![GitHub version](https://badge.fury.io/gh/gorgonia%2Ftensor.svg)](https://badge.fury.io/gh/gorgonia%2Ftensor) [![Build Status](https://travis-ci.org/gorgonia/tensor.svg?branch=master)](https://travis-ci.org/gorgonia/tensor) [![Coverage Status](https://coveralls.io/repos/github/gorgonia/tensor/badge.svg?branch=master)](https://coveralls.io/github/gorgonia/tensor?branch=master) [![Go Report Card](https://goreportcard.com/badge/gorgonia.org/tensor)](https://goreportcard.com/report/gorgonia.org/tensor) [![unstable](http://badges.github.io/stability-badges/dist/unstable.svg)](http://github.com/badges/stability-badges)# + Package `tensor` is a package that provides efficient, generic (by some definitions of generic) n-dimensional arrays in Go. Also in this package are functions and methods that are used commonly in arithmetic, comparison and linear algebra operations. -The main purpose of this package is to support the operations required by [Gorgonia](https://github.com/chewxy/gorgonia). +The main purpose of this package is to support the operations required by [Gorgonia](https://gorgonia.org/gorgonia). ## Introduction ## -In the data analysis world, [Numpy](http://http://www.numpy.org/) and [Matlab](https://www.mathworks.com/products/matlab.html) currently reign supreme. Both tools rely heavily on having performant n-dimensional arrays, or tensors. **There is an obvious need for multidimensional arrays in Go**. +In the data analysis world, [Numpy](http://http://www.numpy.org/) and [Matlab](https://www.mathworks.com/products/matlab.html) currently reign supreme. Both tools rely heavily on having performant n-dimensional arrays, or tensors. **There is an obvious need for multidimensional arrays in Go**. While slices are cool, a large majority of scientific and numeric computing work relies heavily on matrices (two-dimensional arrays), three dimensional arrays and so on. In Go, the typical way of getting multidimensional arrays is to use something like `[][]T`. Applications that are more math heavy may opt to use the very excellent Gonum [`matrix` package](https://github.com/gonum/matrix). What then if we want to go beyond having a `float64` matrix? What if we wanted a 3-dimensional `float32` array? -It comes to reason then there should be a data structure that handles these things. The `tensor` package fits in that niche. +It comes to reason then there should be a data structure that handles these things. The `tensor` package fits in that niche. ### Basic Idea: Tensor ### A tensor is a multidimensional array. It's like a slice, but works in multiple dimensions. -With slices, there are usage patterns that are repeated enough that warrant abstraction - `append`, `len`, `cap`, `range` are abstrations used to manipulate and query slices. Additionally slicing operations (`a[:1]` for example) are also abstractions provided by the language. 
Andrew Gerrand wrote a very good write up on [Go's slice usage and internals](https://blog.golang.org/go-slices-usage-and-internals). +With slices, there are usage patterns that are repeated enough that warrant abstraction - `append`, `len`, `cap`, `range` are abstractions used to manipulate and query slices. Additionally slicing operations (`a[:1]` for example) are also abstractions provided by the language. Andrew Gerrand wrote a very good write up on [Go's slice usage and internals](https://blog.golang.org/go-slices-usage-and-internals). Tensors come with their own set of usage patterns and abstractions. Most of these have analogues in slices, enumerated below (do note that certain slice operation will have more than one tensor analogue - this is due to the number of options available): @@ -25,7 +26,7 @@ Tensors come with their own set of usage patterns and abstractions. Most of thes | `a[0]` | `T.At(x,y)` | | `append(a, ...)`| `T.Stack(...)`, `T.Concat(...)` | | `copy(dest, src)`| `T.CopyTo(dest)`, `tensor.Copy(dest, src)` | -| `for _, v := range a` | `for i, err := iterator.Next(); err == nil; i, err = iterator.Next()` | +| `for _, v := range a` | `for i, err := iterator.Next(); err == nil; i, err = iterator.Next()` | Some operations for a tensor does not have direct analogues to slice operations. However, they stem from the same idea, and can be considered a superset of all operations common to slices. They're enumerated below: @@ -50,15 +51,15 @@ The `*Dense` tensor is the primary tensor and is represented by a singular flat ### Compressed Sparse Column Matrix ### -Coming soon +Documentation Coming soon ### Compressed Sparse Row Matrix ### -Coming soon +Documentation Coming soon ## Usage ## -To install: `go get -u "github.com/chewxy/gorgonia/tensor"` +To install: `go get -u "gorgonia.org/tensor"` To create a matrix with package `tensor` is easy: @@ -76,7 +77,7 @@ fmt.Printf("a:\n%v\n", a) To create a 3-Tensor is just as easy - just put the correct shape and you're good to go: -```go +```go // Creating a (2,3,4) 3-Tensor of float32 b := New(WithBacking(Range(Float32, 0, 24)), WithShape(2, 3, 4)) fmt.Printf("b:\n%1.1f\n", b) @@ -129,7 +130,13 @@ b.SetAt(1000, 0, 1, 2) fmt.Printf("b:\n%v", b) ``` -There is a whole laundry list of methods and functions available at the [godoc](https://godoc.org/github.com/chewxy/gorgonia/tensor) page +There is a whole laundry list of methods and functions available at the [godoc](https://godoc.org/gorgonia.org/tensor) page + + +## API Notes ## + +This package has a notion of "layers" in its API. This section clarifies the different patterns seen in the API. + @@ -141,7 +148,7 @@ The design of the `*Dense` tensor is quite simple in concept. However, let's sta The data structure for `*Dense` is similar, but a lot more complex. Much of the complexity comes from the need to do accounting work on the data structure as well as preserving references to memory locations. This is how the `*Dense` is defined: -```go +```go type Dense struct { *AP array @@ -155,7 +162,7 @@ And here's a visual representation of the `*Dense`. ![dense](https://github.com/gorgonia/tensor/blob/master/media/dense.png?raw=true) -`*Dense` draws its inspiration from Go's slice. Underlying it all is a flat array, and access to elements are controlled by `*AP`. Where a Go is able to store its metadata in a 3-word stucture (obiviating the need to allocate memory), a `*Dense` unfortunately needs to allocate some memory. 
The majority of the data is stored in the `*AP` structure, which contains metadata such as shape, stride, and methods for accessing the array. +`*Dense` draws its inspiration from Go's slice. Underlying it all is a flat array, and access to elements is controlled by `*AP`. Where a Go slice is able to store its metadata in a 3-word structure (obviating the need to allocate memory), a `*Dense` unfortunately needs to allocate some memory. The majority of the data is stored in the `*AP` structure, which contains metadata such as shape, stride, and methods for accessing the array. `*Dense` embeds an `array` (not to be confused with Go's array), which is an abstracted data structure that looks like this: @@ -167,7 +174,7 @@ type array struct { } ``` -`*storage.Header` is the same structure as `reflect.SliceHeader`, except it stores a `unsafe.Pointer` instead of a `uintptr`. This is done so that eventually when more tests are done to determine how the garbage collector marks data, the `v` field may be removed. +`*storage.Header` is the same structure as `reflect.SliceHeader`, except it stores an `unsafe.Pointer` instead of a `uintptr`. This is done so that eventually when more tests are done to determine how the garbage collector marks data, the `v` field may be removed. The `storage.Header` field of the `array` (and hence `*Dense`) is there to provide a quick and easy way to translate back into a slice for operations that use familiar slice semantics, of which much of the operations are dependent upon. @@ -198,27 +205,29 @@ The above call will use `myEngine` to allocate memory instead. This is useful in ### Other failed designs ### -The alternative designs can be seen in the [ALTERNATIVE DESIGNS document](https://github.com/chewxy/gorgonia/blob/master/tensor/ALTERNATIVEDESIGNS.md) +The alternative designs can be seen in the [ALTERNATIVE DESIGNS document](https://github.com/gorgonia/tensor/blob/master/ALTERNATIVEDESIGNS.md) ## Generic Features ## Example: -```go +```go x := New(WithBacking([]string{"hello", "world", "hello", "world"}), WithShape(2,2)) x = New(WithBacking([]int{1,2,3,4}), WithShape(2,2)) ``` -The above code will not cause a compile error, because the structure holding the underlying array (of `string`s and then of `int`s) is a `*Dense`. +The above code will not cause a compile error, because the structure holding the underlying array (of `string`s and then of `int`s) is a `*Dense`. One could argue that this sidesteps the compiler's type checking system, deferring it to runtime (which a number of people consider dangerous). However, tools are being developed to type check these things, and until Go does support typechecked generics, unfortunately this will be the way it has to be. -Currently, the tensor package supports limited type of genericity - limited to a tensor of any primitive type. +Currently, the tensor package supports a limited form of genericity - limited to a tensor of any primitive type. # How This Package is Developed # -Much of the code in this package is generated. The code to generate them is in the directory `genlib2`. +Much of the code in this package is generated. The code to generate it is in the directory `genlib2`. `genlib2` requires [`goimports`](https://godoc.org/golang.org/x/tools/cmd/goimports) binary to be available in the $PATH. +## Tests ## +Tests require Python with NumPy installed. You can select which Python interpreter is being used by setting the environment variable `PYTHON_COMMAND` accordingly. The default value is `python`.
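+For illustration only, the convention described above amounts to something like the following sketch. This is a hypothetical helper written for this README, not code from the package; the package name and function name are made up, and the only assumptions are the documented `PYTHON_COMMAND` variable and Go's standard `os` package:
+
+```go
+package tensor_test
+
+import "os"
+
+// pythonCommand returns the interpreter the Python-backed tests would shell out to:
+// whatever PYTHON_COMMAND names, or "python" when the variable is unset.
+func pythonCommand() string {
+	if cmd := os.Getenv("PYTHON_COMMAND"); cmd != "" {
+		return cmd
+	}
+	return "python"
+}
+```
+
+For example, running `PYTHON_COMMAND=python3 go test ./...` would make the tests use `python3` instead of the default `python`.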
## Things Knowingly Untested For ## - `complex64` and `complex128` are excluded from quick check generation process [Issue #11](https://github.com/gorgonia/tensor/issues/11) @@ -243,7 +252,7 @@ See also: CONTRIBUTING.md ## Contributors and Significant Contributors ## -All contributions are welcome. However, there is a new class of contributor, called Significant Contributors. +All contributions are welcome. However, there is a new class of contributor, called Significant Contributors. A Significant Contributor is one who has shown *deep understanding* of how the library works and/or its environs. Here are examples of what constitutes a Significant Contribution: diff --git a/ap.go b/ap.go index b4b9176..85bf9a7 100644 --- a/ap.go +++ b/ap.go @@ -8,9 +8,12 @@ import ( // An AP is an access pattern. It tells the various ndarrays how to access their data through the use of strides // Through the AP, there are several definitions of things, most notably there are two very specific "special cases": -// Scalar has Dims() of 0. However, its shape can take several forms: -// - (1, 1) +// Scalar has Dims() of 0. // - (1) +// Scalarlikes are higher order tensors, but each with a size of 1. The Dims() are not 0. +// - (1, 1) +// - (1, 1, 1) +// - (1, 1, 1, 1), etc // Vector has Dims() of 1, but its shape can take several forms: // - (x, 1) // - (1, x) @@ -26,13 +29,30 @@ type AP struct { Δ Triangle } -// NewAP creates a new AP, given the shape and strides -func NewAP(shape Shape, strides []int) *AP { - ap := borrowAP() +func makeAP(size int) AP { + return AP{ + shape: Shape(BorrowInts(size)), + strides: BorrowInts(size), + } +} + +// MakeAP creates an AP, given the shape and strides. +func MakeAP(shape Shape, strides []int, o DataOrder, Δ Triangle) AP { + return AP{ + shape: shape, + strides: strides, + o: o, + Δ: Δ, + fin: true, + } +} + +// Init initializes an already created AP with a shape and stries. +// It will panic if AP is nil. +func (ap *AP) Init(shape Shape, strides []int) { ap.shape = shape ap.strides = strides ap.fin = true - return ap } // SetShape is for very specific times when modifying the AP is necessary, such as reshaping and doing I/O related stuff @@ -46,6 +66,9 @@ func (ap *AP) SetShape(s ...int) { if !ap.fin { // scalars are a special case, we don't want to remove it completely if len(s) == 0 { + if ap.shape == nil || ap.strides == nil { + ap.shape = Shape{} + } ap.shape = ap.shape[:0] ap.strides = ap.strides[:0] return @@ -90,21 +113,74 @@ func (ap *AP) Format(state fmt.State, c rune) { // row vector func (ap *AP) IsVector() bool { return ap.shape.IsVector() } +// IsVectorLike returns true if the shape is vector-like (i.e. the shape only has one dim that is a non-1). +func (ap *AP) IsVectorLike() bool { + return ap.shape.IsVectorLike() && allones(ap.strides) +} + // IsColVec returns true when the access pattern has the shape (x, 1) func (ap *AP) IsColVec() bool { return ap.shape.IsColVec() } // IsRowVec returns true when the access pattern has the shape (1, x) func (ap *AP) IsRowVec() bool { return ap.shape.IsRowVec() } -// IsScalar returns true if the access pattern indicates it's a scalar value +// IsScalar returns true if the access pattern indicates it's a scalar value. func (ap *AP) IsScalar() bool { return ap.shape.IsScalar() } +// IsScalarEquiv returns true if the access pattern is equivalent to a scalar shape. +func (ap *AP) IsScalarEquiv() bool { return ap.shape.IsScalarEquiv() } + // IsMatrix returns true if it's a matrix. This is mostly a convenience method. 
RowVec and ColVecs are also considered matrices func (ap *AP) IsMatrix() bool { return len(ap.shape) == 2 } -// Clone clones the *AP. Clearly. -func (ap *AP) Clone() (retVal *AP) { - retVal = BorrowAP(len(ap.shape)) +// IsZero tell us if the ap has zero size +func (ap *AP) IsZero() bool { + return len(ap.shape) == 0 && len(ap.strides) == 0 && !ap.fin && ap.o == 0 && ap.Δ == 0 +} + +// Zero zeros out an AP. +func (ap *AP) zero() { + // log.Printf("ZEROING. Called by %v", string(debug.Stack())) + + // Jorge's original implementation for zeroing a AP is as below + // but to cater for the (*Dense).fix() method of the *Dense + // a nil shape is used to signal unsetness + // so we cannot just truncate the shape even though it would be a lot more efficient + + // ap.shape = ap.shape[:0] + // ap.strides = ap.strides[:0] + ReturnInts([]int(ap.shape)) + ReturnInts(ap.strides) + ap.zeroOnly() +} + +// side effect free zeroing +func (ap *AP) zeroOnly() { + ap.shape = nil + ap.strides = nil + + ap.fin = false + ap.o = 0 + ap.Δ = 0 +} + +func (ap *AP) zeroWithDims(dims int) { + //ap.shape = BorrowInts(dims) + //ap.strides = BorrowInts(dims) + if cap(ap.shape) >= dims { + ap.shape = ap.shape[:dims] + } + ap.shape = BorrowInts(dims) + if cap(ap.strides) >= dims { + ap.strides = ap.strides[:dims] + } + ap.strides = BorrowInts(dims) +} + +// Clone clones the *AP. Clearly. It returns AP +func (ap *AP) Clone() (retVal AP) { + retVal = makeAP(cap(ap.shape)) + copy(retVal.shape, ap.shape) copy(retVal.strides, ap.strides) @@ -118,21 +194,25 @@ func (ap *AP) Clone() (retVal *AP) { return } +func (ap *AP) CloneTo(dest *AP) { + dest.shape = append(dest.shape[:0], ap.shape...) + dest.strides = append(dest.strides[:0], ap.strides...) + dest.fin = ap.fin + dest.o = ap.o + dest.Δ = ap.Δ +} + // DataOrder returns the data order of the AP. func (ap *AP) DataOrder() DataOrder { return ap.o } // C returns true if the access pattern is C-contiguous array -func (ap *AP) C() bool { - return ap.o.isRowMajor() && ap.o.isContiguous() -} +func (ap *AP) C() bool { return ap.o.IsRowMajor() && ap.o.IsContiguous() } // F returns true if the access pattern is Fortran contiguous array -func (ap *AP) F() bool { - return ap.o.isColMajor() && ap.o.isContiguous() -} +func (ap *AP) F() bool { return ap.o.IsColMajor() && ap.o.IsContiguous() } // S returns the metadata of the sliced tensor. 
-func (ap *AP) S(size int, slices ...Slice) (newAP *AP, ndStart, ndEnd int, err error) { +func (ap *AP) S(size int, slices ...Slice) (newAP AP, ndStart, ndEnd int, err error) { if len(slices) > len(ap.shape) { // error err = errors.Errorf(dimMismatch, len(ap.shape), len(slices)) @@ -146,7 +226,7 @@ func (ap *AP) S(size int, slices ...Slice) (newAP *AP, ndStart, ndEnd int, err e var outerDim int order := ap.o - if ap.o.isRowMajor() || ap.IsVector() { + if ap.o.IsRowMajor() || ap.IsVector() { outerDim = 0 } else { outerDim = len(ap.shape) - 1 @@ -160,12 +240,13 @@ func (ap *AP) S(size int, slices ...Slice) (newAP *AP, ndStart, ndEnd int, err e size := ap.shape[i] var stride int - if ap.IsVector() { - // handles non-vanilla vectors - stride = ap.strides[0] - } else { - stride = ap.strides[i] - } + stride = ap.strides[i] + // if ap.IsVector() { + // // handles non-vanilla vectors + // stride = ap.strides[0] + // } else { + // stride = ap.strides[i] + // } var start, end, step int if start, end, step, err = SliceDetails(sl, size); err != nil { @@ -176,8 +257,11 @@ func (ap *AP) S(size int, slices ...Slice) (newAP *AP, ndStart, ndEnd int, err e // a slice where start == end is [] ndStart = ndStart + start*stride ndEnd = ndEnd - (size-end)*stride + if step > 0 { - newShape[i] = (end - start) / step + if newShape[i] = (end - start) / step; (end-start)%step > 0 && i > 0 { + newShape[i]++ + } newStrides[i] = stride * step //fix @@ -196,37 +280,31 @@ func (ap *AP) S(size int, slices ...Slice) (newAP *AP, ndStart, ndEnd int, err e if ndEnd-ndStart == 1 { // scalars are a special case - newAP = borrowAP() + newAP = AP{} newAP.SetShape() // make it a Scalar newAP.lock() } else { // drop any dimension with size 1, except the last dimension + offset := 0 for d := 0; d < dims; d++ { - if newShape[d] == 1 /*&& d != t.dims-1 && dims > 2*/ { + if newShape[d] == 1 && offset+d <= len(slices)-1 && slices[offset+d] != nil /*&& d != t.dims-1 && dims > 2*/ { newShape = append(newShape[:d], newShape[d+1:]...) newStrides = append(newStrides[:d], newStrides[d+1:]...) d-- dims-- + offset++ } } - //fix up strides - if newShape.IsColVec() { - stride0 := newStrides[0] - ReturnInts(newStrides) - newStrides = BorrowInts(1) - newStrides[0] = stride0 - } - - newAP = NewAP(newShape, newStrides) - newAP.o = order + newAP = MakeAP(newShape, newStrides, order, ap.Δ) } return } // T returns the transposed metadata based on the given input -func (ap *AP) T(axes ...int) (retVal *AP, a []int, err error) { +func (ap *AP) T(axes ...int) (retVal AP, a []int, err error) { + // prep axes if len(axes) > 0 && len(axes) != ap.Dims() { err = errors.Errorf(dimMismatch, ap.Dims(), len(axes)) @@ -242,9 +320,13 @@ func (ap *AP) T(axes ...int) (retVal *AP, a []int, err error) { } a = axes + if ap.shape.IsScalarEquiv() { + return ap.Clone(), a, noopError{} + } + // if axes is 0, 1, 2, 3... 
then no op if monotonic, incr1 := IsMonotonicInts(axes); monotonic && incr1 && axes[0] == 0 { - return ap, a, noopError{} + return ap.Clone(), a, noopError{} } currentShape := ap.shape @@ -259,7 +341,7 @@ func (ap *AP) T(axes ...int) (retVal *AP, a []int, err error) { if axes[0] == 0 { return } - copy(strides, currentStride) + strides[0], strides[1] = 1, 1 shape[0], shape[1] = currentShape[1], currentShape[0] default: copy(shape, currentShape) @@ -270,12 +352,8 @@ func (ap *AP) T(axes ...int) (retVal *AP, a []int, err error) { } } - retVal = borrowAP() - retVal.shape = shape - retVal.strides = strides - if ap.IsVector() { - retVal.strides = retVal.strides[:1] - } + o := MakeDataOrder(ap.o, Transposed) + retVal = MakeAP(shape, strides, o, ap.Δ) retVal.fin = true return } @@ -286,14 +364,21 @@ func (ap *AP) unlock() { ap.fin = false } func (ap *AP) calcStrides() []int { switch { - case ap.o.isRowMajor(): - return ap.shape.calcStrides() - case ap.o.isColMajor(): - return ap.shape.calcStridesColMajor() + case ap.o.IsRowMajor(): + return CalcStrides(ap.shape) + case ap.o.IsColMajor(): + return CalcStridesColMajor(ap.shape) } panic("unreachable") } +// setDataOrder is a method such that any tensor that embeds *AP will have the same method +func (ap *AP) setDataOrder(o DataOrder) { + if !o.HasSameOrder(ap.o) { + ap.o = ap.o.toggleColMajor() + } +} + // TransposeIndex returns the new index given the old index func TransposeIndex(i int, oldShape, pattern, oldStrides, newStrides []int) int { oldCoord, err := Itol(i, oldShape, oldStrides) diff --git a/ap_test.go b/ap_test.go index 37e0a28..3a0e5bb 100644 --- a/ap_test.go +++ b/ap_test.go @@ -7,71 +7,40 @@ import ( "github.com/stretchr/testify/assert" ) -type dummySlice struct { - start, end, step int -} - -func (s dummySlice) Start() int { return s.start } -func (s dummySlice) End() int { return s.end } -func (s dummySlice) Step() int { return s.step } - -func sli(start int, opt ...int) dummySlice { - var end, step int - switch len(opt) { - case 0: - end = start + 1 - step = 0 - case 1: - end = opt[0] - step = 1 - default: - end = opt[0] - step = opt[1] +func dummyScalar1() AP { return AP{} } - } - return dummySlice{start: start, end: end, step: step} -} +func dummyScalar2() AP { return AP{shape: Shape{1}} } -func dummyScalar1() *AP { - return &AP{} -} - -func dummyScalar2() *AP { - return &AP{ - shape: Shape{1}, - } -} - -func dummyColVec() *AP { - return &AP{ +func dummyColVec() AP { + return AP{ shape: Shape{5, 1}, strides: []int{1}, } } -func dummyRowVec() *AP { - return &AP{ +func dummyRowVec() AP { + return AP{ shape: Shape{1, 5}, strides: []int{1}, } } -func dummyVec() *AP { - return &AP{ +func dummyVec() AP { + return AP{ shape: Shape{5}, strides: []int{1}, } } -func twothree() *AP { - return &AP{ +func twothree() AP { + return AP{ shape: Shape{2, 3}, strides: []int{3, 1}, } } -func twothreefour() *AP { - return &AP{ +func twothreefour() AP { + return AP{ shape: Shape{2, 3, 4}, strides: []int{12, 4, 1}, } @@ -83,7 +52,7 @@ func TestAccessPatternBasics(t *testing.T) { ap.SetShape(1, 2) assert.Equal(Shape{1, 2}, ap.Shape()) - assert.Equal([]int{1}, ap.Strides()) + assert.Equal([]int{2, 1}, ap.Strides()) assert.Equal(2, ap.Dims()) assert.Equal(2, ap.Size()) @@ -100,31 +69,34 @@ func TestAccessPatternBasics(t *testing.T) { ap.unlock() ap.SetShape(1, 2) assert.Equal(Shape{1, 2}, ap.Shape()) - assert.Equal([]int{1}, ap.Strides()) + assert.Equal([]int{2, 1}, ap.Strides()) assert.Equal(2, ap.Dims()) assert.Equal(2, ap.Size()) - if ap.String() != 
"Shape: (1, 2), Stride: [1], Lock: false" { - t.Error("AP formatting error. Got %q", ap.String()) + if ap.String() != "Shape: (1, 2), Stride: [2 1], Lock: false" { + t.Errorf("AP formatting error. Got %q", ap.String()) } ap2 := ap.Clone() - assert.Equal(ap, ap2) + assert.Equal(*ap, ap2) } func TestAccessPatternIsX(t *testing.T) { assert := assert.New(t) - var ap *AP + var ap AP ap = dummyScalar1() assert.True(ap.IsScalar()) + assert.True(ap.IsScalarEquiv()) assert.False(ap.IsVector()) assert.False(ap.IsColVec()) assert.False(ap.IsRowVec()) ap = dummyScalar2() - assert.True(ap.IsScalar()) - assert.False(ap.IsVector()) + assert.False(ap.IsScalar()) + assert.True(ap.IsScalarEquiv()) + assert.True(ap.IsVectorLike()) + assert.True(ap.IsVector()) assert.False(ap.IsColVec()) assert.False(ap.IsRowVec()) @@ -151,7 +123,7 @@ func TestAccessPatternIsX(t *testing.T) { func TestAccessPatternT(t *testing.T) { assert := assert.New(t) - var ap, apT *AP + var ap, apT AP var axes []int var err error @@ -206,26 +178,32 @@ var sliceTests = []struct { contiguous bool }{ // vectors - {"a[0]", Shape{5}, []Slice{sli(0)}, 0, 1, ScalarShape(), nil, true}, - {"a[0:2]", Shape{5}, []Slice{sli(0, 2)}, 0, 2, Shape{2}, []int{1}, true}, - {"a[1:3]", Shape{5}, []Slice{sli(1, 3)}, 1, 3, Shape{2}, []int{1}, true}, - {"a[1:5:2]", Shape{5}, []Slice{sli(1, 5, 2)}, 1, 5, Shape{2}, []int{2}, false}, + {"a[0]", Shape{5}, []Slice{S(0)}, 0, 1, ScalarShape(), nil, true}, + {"a[0:2]", Shape{5}, []Slice{S(0, 2)}, 0, 2, Shape{2}, []int{1}, true}, + {"a[1:3]", Shape{5}, []Slice{S(1, 3)}, 1, 3, Shape{2}, []int{1}, true}, + {"a[1:5:2]", Shape{5}, []Slice{S(1, 5, 2)}, 1, 5, Shape{2}, []int{2}, false}, // matrix - {"A[0]", Shape{2, 3}, []Slice{sli(0)}, 0, 3, Shape{1, 3}, []int{1}, true}, - {"A[1:3]", Shape{4, 5}, []Slice{sli(1, 3)}, 5, 15, Shape{2, 5}, []int{5, 1}, true}, - {"A[0:10] (intentionally over)", Shape{4, 5}, []Slice{sli(0, 10)}, 0, 20, Shape{4, 5}, []int{5, 1}, true}, // as if nothing happened - {"A[:, 1:3]", Shape{4, 5}, []Slice{nil, sli(1, 3)}, 1, 18, Shape{4, 2}, []int{5, 1}, false}, + {"A[0]", Shape{2, 3}, []Slice{S(0)}, 0, 3, Shape{1, 3}, []int{1}, true}, + {"A[1:3]", Shape{4, 5}, []Slice{S(1, 3)}, 5, 15, Shape{2, 5}, []int{5, 1}, true}, + {"A[0:10] (intentionally over)", Shape{4, 5}, []Slice{S(0, 10)}, 0, 20, Shape{4, 5}, []int{5, 1}, true}, // as if nothing happened + {"A[:, 1:3]", Shape{4, 5}, []Slice{nil, S(1, 3)}, 1, 18, Shape{4, 2}, []int{5, 1}, false}, + + // tensor + {"tensor[0, :, :]", Shape{1, 2, 2}, []Slice{rs{0, 1, 1}, nil, nil}, 0, 4, Shape{2, 2}, []int{2, 1}, true}, + {"tensor[:, 0, :]", Shape{1, 2, 2}, []Slice{nil, rs{0, 1, 1}, nil}, 0, 2, Shape{1, 2}, []int{4, 1}, false}, + {"tensor[0, :, :, :]", Shape{1, 1, 2, 2}, []Slice{rs{0, 1, 1}, nil, nil, nil}, 0, 4, Shape{1, 2, 2}, []int{4, 2, 1}, true}, + {"tensor[0,]", Shape{1, 1, 2, 2}, []Slice{rs{0, 1, 1}}, 0, 4, Shape{1, 2, 2}, []int{4, 2, 1}, true}, } func TestAccessPatternS(t *testing.T) { assert := assert.New(t) - var ap, apS *AP + var ap, apS AP var ndStart, ndEnd int var err error for _, sts := range sliceTests { - ap = NewAP(sts.shape, sts.shape.calcStrides()) + ap = MakeAP(sts.shape, CalcStrides(sts.shape), 0, 0) if apS, ndStart, ndEnd, err = ap.S(sts.shape.TotalSize(), sts.slices...); err != nil { t.Errorf("%v errored: %v", sts.name, err) continue @@ -234,7 +212,7 @@ func TestAccessPatternS(t *testing.T) { assert.Equal(sts.correctEnd, ndEnd, "Wrong end: %v. 
Want %d Got %d", sts.name, sts.correctEnd, ndEnd) assert.True(sts.correctShape.Eq(apS.shape), "Wrong shape: %v. Want %v. Got %v", sts.name, sts.correctShape, apS.shape) assert.Equal(sts.correctStride, apS.strides, "Wrong strides: %v. Want %v. Got %v", sts.name, sts.correctStride, apS.strides) - assert.Equal(sts.contiguous, apS.DataOrder().isContiguous(), "Wrong contiguity for %v Want %t.", sts.name, sts.contiguous) + assert.Equal(sts.contiguous, apS.DataOrder().IsContiguous(), "Wrong contiguity for %v Want %t.", sts.name, sts.contiguous) } } diff --git a/api_arith.go b/api_arith.go index b70dd9b..13ccd05 100644 --- a/api_arith.go +++ b/api_arith.go @@ -19,26 +19,53 @@ import ( // If the Unsafe flag is passed in, the data of the first tensor will be overwritten func Add(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var adder Adder - var oe standardEngine + var oe StandardEngine var ok bool switch at := a.(type) { case Tensor: - oe = at.standardEngine() + oe, _ = at.Engine().(StandardEngine) switch bt := b.(type) { case Tensor: - if oe != nil { - return oe.Add(at, bt, opts...) - } - if oe = bt.standardEngine(); oe != nil { - return oe.Add(at, bt, opts...) - } - if adder, ok = at.Engine().(Adder); ok { - return adder.Add(at, bt, opts...) - } - if adder, ok = bt.Engine().(Adder); ok { - return adder.Add(at, bt, opts...) + if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor addition + if oe != nil { + return oe.Add(at, bt, opts...) + } + if oe, ok = bt.Engine().(StandardEngine); ok { + return oe.Add(at, bt, opts...) + } + if adder, ok = at.Engine().(Adder); ok { + return adder.Add(at, bt, opts...) + } + if adder, ok = bt.Engine().(Adder); ok { + return adder.Add(at, bt, opts...) + } + return nil, errors.New("Neither engines of either operand support Add") + + } else { // at least one of the operands is a scalar + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false // a Scalar-Tensor * b Tensor + tmp := at + at = bt + bt = tmp + } else { + leftTensor = true // a Tensor * b Scalar-Tensor + } + + if oe != nil { + return oe.AddScalar(at, bt, leftTensor, opts...) + } + if oe, ok = bt.Engine().(StandardEngine); ok { + return oe.AddScalar(at, bt, leftTensor, opts...) + } + if adder, ok = at.Engine().(Adder); ok { + return adder.AddScalar(at, bt, leftTensor, opts...) + } + if adder, ok = bt.Engine().(Adder); ok { + return adder.AddScalar(at, bt, leftTensor, opts...) + } + return nil, errors.New("Neither engines of either operand support Add") } - return nil, errors.New("Neither engines of either operand support Add") default: if oe != nil { @@ -52,7 +79,7 @@ func Add(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { default: switch bt := b.(type) { case Tensor: - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.AddScalar(bt, at, false, opts...) } if adder, ok = bt.Engine().(Adder); ok { @@ -73,26 +100,53 @@ func Add(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { // If the Unsafe flag is passed in, the data of the first tensor will be overwritten func Sub(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var suber Suber - var oe standardEngine + var oe StandardEngine var ok bool switch at := a.(type) { case Tensor: - oe = at.standardEngine() + oe, _ = at.Engine().(StandardEngine) switch bt := b.(type) { case Tensor: - if oe != nil { - return oe.Sub(at, bt, opts...) - } - if oe = bt.standardEngine(); oe != nil { - return oe.Sub(at, bt, opts...) 
- } - if suber, ok = at.Engine().(Suber); ok { - return suber.Sub(at, bt, opts...) - } - if suber, ok = bt.Engine().(Suber); ok { - return suber.Sub(at, bt, opts...) + if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor substraction + if oe != nil { + return oe.Sub(at, bt, opts...) + } + if oe, ok = bt.Engine().(StandardEngine); ok { + return oe.Sub(at, bt, opts...) + } + if suber, ok = at.Engine().(Suber); ok { + return suber.Sub(at, bt, opts...) + } + if suber, ok = bt.Engine().(Suber); ok { + return suber.Sub(at, bt, opts...) + } + return nil, errors.New("Neither engines of either operand support Sub") + + } else { // at least one of the operands is a scalar + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false // a Scalar-Tensor * b Tensor + tmp := at + at = bt + bt = tmp + } else { + leftTensor = true // a Tensor * b Scalar-Tensor + } + + if oe != nil { + return oe.SubScalar(at, bt, leftTensor, opts...) + } + if oe, ok = bt.Engine().(StandardEngine); ok { + return oe.SubScalar(at, bt, leftTensor, opts...) + } + if suber, ok = at.Engine().(Suber); ok { + return suber.SubScalar(at, bt, leftTensor, opts...) + } + if suber, ok = bt.Engine().(Suber); ok { + return suber.SubScalar(at, bt, leftTensor, opts...) + } + return nil, errors.New("Neither engines of either operand support Sub") } - return nil, errors.New("Neither engines of either operand support Sub") default: if oe != nil { @@ -106,7 +160,7 @@ func Sub(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { default: switch bt := b.(type) { case Tensor: - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.SubScalar(bt, at, false, opts...) } if suber, ok = bt.Engine().(Suber); ok { @@ -127,18 +181,18 @@ func Sub(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { // If the Unsafe flag is passed in, the data of the first tensor will be overwritten func Mul(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var muler Muler - var oe standardEngine + var oe StandardEngine var ok bool switch at := a.(type) { case Tensor: - oe = at.standardEngine() + oe, _ = at.Engine().(StandardEngine) switch bt := b.(type) { case Tensor: if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor multiplication if oe != nil { return oe.Mul(at, bt, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.Mul(at, bt, opts...) } if muler, ok = at.Engine().(Muler); ok { @@ -149,10 +203,13 @@ func Mul(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { } return nil, errors.New("Neither engines of either operand support Mul") - } else { // one of the operands is a scalar + } else { // at least one of the operands is a scalar var leftTensor bool - if at.Shape().IsScalar() { + if !bt.Shape().IsScalar() { leftTensor = false // a Scalar-Tensor * b Tensor + tmp := at + at = bt + bt = tmp } else { leftTensor = true // a Tensor * b Scalar-Tensor } @@ -160,7 +217,7 @@ func Mul(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { if oe != nil { return oe.MulScalar(at, bt, leftTensor, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.MulScalar(at, bt, leftTensor, opts...) 
} if muler, ok = at.Engine().(Muler); ok { @@ -185,7 +242,7 @@ func Mul(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { default: switch bt := b.(type) { case Tensor: // b Tensor * a interface - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.MulScalar(bt, at, false, opts...) } if muler, ok = bt.Engine().(Muler); ok { @@ -207,26 +264,53 @@ func Mul(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { // If the Unsafe flag is passed in, the data of the first tensor will be overwritten func Div(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var diver Diver - var oe standardEngine + var oe StandardEngine var ok bool switch at := a.(type) { case Tensor: - oe = at.standardEngine() + oe, _ = at.Engine().(StandardEngine) switch bt := b.(type) { case Tensor: - if oe != nil { - return oe.Div(at, bt, opts...) - } - if oe = bt.standardEngine(); oe != nil { - return oe.Div(at, bt, opts...) - } - if diver, ok = at.Engine().(Diver); ok { - return diver.Div(at, bt, opts...) - } - if diver, ok = bt.Engine().(Diver); ok { - return diver.Div(at, bt, opts...) + if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor division + if oe != nil { + return oe.Div(at, bt, opts...) + } + if oe, ok = bt.Engine().(StandardEngine); ok { + return oe.Div(at, bt, opts...) + } + if diver, ok = at.Engine().(Diver); ok { + return diver.Div(at, bt, opts...) + } + if diver, ok = bt.Engine().(Diver); ok { + return diver.Div(at, bt, opts...) + } + return nil, errors.New("Neither engines of either operand support Div") + + } else { // at least one of the operands is a scalar + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false // a Scalar-Tensor * b Tensor + tmp := at + at = bt + bt = tmp + } else { + leftTensor = true // a Tensor * b Scalar-Tensor + } + + if oe != nil { + return oe.DivScalar(at, bt, leftTensor, opts...) + } + if oe, ok = bt.Engine().(StandardEngine); ok { + return oe.DivScalar(at, bt, leftTensor, opts...) + } + if diver, ok = at.Engine().(Diver); ok { + return diver.DivScalar(at, bt, leftTensor, opts...) + } + if diver, ok = bt.Engine().(Diver); ok { + return diver.DivScalar(at, bt, leftTensor, opts...) + } + return nil, errors.New("Neither engines of either operand support Div") } - return nil, errors.New("Neither engines of either operand support Div") default: if oe != nil { @@ -240,7 +324,7 @@ func Div(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { default: switch bt := b.(type) { case Tensor: - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.DivScalar(bt, at, false, opts...) } if diver, ok = bt.Engine().(Diver); ok { @@ -261,26 +345,53 @@ func Div(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { // If the Unsafe flag is passed in, the data of the first tensor will be overwritten func Pow(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var power Power - var oe standardEngine + var oe StandardEngine var ok bool switch at := a.(type) { case Tensor: - oe = at.standardEngine() + oe, _ = at.Engine().(StandardEngine) switch bt := b.(type) { case Tensor: - if oe != nil { - return oe.Pow(at, bt, opts...) - } - if oe = bt.standardEngine(); oe != nil { - return oe.Pow(at, bt, opts...) - } - if power, ok = at.Engine().(Power); ok { - return power.Pow(at, bt, opts...) - } - if power, ok = bt.Engine().(Power); ok { - return power.Pow(at, bt, opts...) 
+ if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor exponentiation + if oe != nil { + return oe.Pow(at, bt, opts...) + } + if oe, ok = bt.Engine().(StandardEngine); ok { + return oe.Pow(at, bt, opts...) + } + if power, ok = at.Engine().(Power); ok { + return power.Pow(at, bt, opts...) + } + if power, ok = bt.Engine().(Power); ok { + return power.Pow(at, bt, opts...) + } + return nil, errors.New("Neither engines of either operand support Pow") + + } else { // at least one of the operands is a scalar + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false // a Scalar-Tensor * b Tensor + tmp := at + at = bt + bt = tmp + } else { + leftTensor = true // a Tensor * b Scalar-Tensor + } + + if oe != nil { + return oe.PowScalar(at, bt, leftTensor, opts...) + } + if oe, ok = bt.Engine().(StandardEngine); ok { + return oe.PowScalar(at, bt, leftTensor, opts...) + } + if power, ok = at.Engine().(Power); ok { + return power.PowScalar(at, bt, leftTensor, opts...) + } + if power, ok = bt.Engine().(Power); ok { + return power.PowScalar(at, bt, leftTensor, opts...) + } + return nil, errors.New("Neither engines of either operand support Pow") } - return nil, errors.New("Neither engines of either operand support Pow") default: if oe != nil { @@ -294,7 +405,7 @@ func Pow(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { default: switch bt := b.(type) { case Tensor: - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.PowScalar(bt, at, false, opts...) } if power, ok = bt.Engine().(Power); ok { @@ -308,33 +419,60 @@ func Pow(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { panic("Unreachable") } -// Mod performs elementwise exponentiation on the Tensor(s). These operations are supported: +// Mod performs elementwise modulo on the Tensor(s). These operations are supported: // Mod(*Dense, scalar) // Mod(scalar, *Dense) // Mod(*Dense, *Dense) // If the Unsafe flag is passed in, the data of the first tensor will be overwritten func Mod(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var moder Moder - var oe standardEngine + var oe StandardEngine var ok bool switch at := a.(type) { case Tensor: - oe = at.standardEngine() + oe, _ = at.Engine().(StandardEngine) switch bt := b.(type) { case Tensor: - if oe != nil { - return oe.Mod(at, bt, opts...) - } - if oe = bt.standardEngine(); oe != nil { - return oe.Mod(at, bt, opts...) - } - if moder, ok = at.Engine().(Moder); ok { - return moder.Mod(at, bt, opts...) - } - if moder, ok = bt.Engine().(Moder); ok { - return moder.Mod(at, bt, opts...) + if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor modulo + if oe != nil { + return oe.Mod(at, bt, opts...) + } + if oe, ok = bt.Engine().(StandardEngine); ok { + return oe.Mod(at, bt, opts...) + } + if moder, ok = at.Engine().(Moder); ok { + return moder.Mod(at, bt, opts...) + } + if moder, ok = bt.Engine().(Moder); ok { + return moder.Mod(at, bt, opts...) + } + return nil, errors.New("Neither engines of either operand support Mod") + + } else { // at least one of the operands is a scalar + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false // a Scalar-Tensor * b Tensor + tmp := at + at = bt + bt = tmp + } else { + leftTensor = true // a Tensor * b Scalar-Tensor + } + + if oe != nil { + return oe.ModScalar(at, bt, leftTensor, opts...) + } + if oe, ok = bt.Engine().(StandardEngine); ok { + return oe.ModScalar(at, bt, leftTensor, opts...) 
+ } + if moder, ok = at.Engine().(Moder); ok { + return moder.ModScalar(at, bt, leftTensor, opts...) + } + if moder, ok = bt.Engine().(Moder); ok { + return moder.ModScalar(at, bt, leftTensor, opts...) + } + return nil, errors.New("Neither engines of either operand support Mod") } - return nil, errors.New("Neither engines of either operand support Mod") default: if oe != nil { @@ -348,7 +486,7 @@ func Mod(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { default: switch bt := b.(type) { case Tensor: - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.ModScalar(bt, at, false, opts...) } if moder, ok = bt.Engine().(Moder); ok { @@ -388,41 +526,28 @@ func Dot(x, y Tensor, opts ...FuncOpt) (retVal Tensor, err error) { // FMA performs Y = A * X + Y. func FMA(a Tensor, x interface{}, y Tensor) (retVal Tensor, err error) { - if xTensor, ok := x.(Tensor); ok { - if oe := a.standardEngine(); oe != nil { - return oe.FMA(a, xTensor, y) - } - if oe := xTensor.standardEngine(); oe != nil { - return oe.FMA(a, xTensor, y) - } - if oe := y.standardEngine(); oe != nil { - return oe.FMA(a, xTensor, y) - } + var fm FMAer - if e, ok := a.Engine().(FMAer); ok { - return e.FMA(a, xTensor, y) - } - if e, ok := xTensor.Engine().(FMAer); ok { - return e.FMA(a, xTensor, y) - } - if e, ok := y.Engine().(FMAer); ok { - return e.FMA(a, xTensor, y) + if xTensor, ok := x.(Tensor); ok { + for _, T := range [3]Tensor{a, xTensor, y} { + e := T.Engine() + ctx := ctxFromEngine(e) + fm, ok = e.(FMAer) + if ok { + return fm.FMA(ctx, a, xTensor, y) + } } } else { - if oe := a.standardEngine(); oe != nil { - return oe.FMAScalar(a, x, y) - } - if oe := y.standardEngine(); oe != nil { - return oe.FMAScalar(a, x, y) - } - - if e, ok := a.Engine().(FMAer); ok { - return e.FMAScalar(a, x, y) - } - if e, ok := y.Engine().(FMAer); ok { - return e.FMAScalar(a, x, y) + for _, T := range [2]Tensor{a, y} { + e := T.Engine() + ctx := ctxFromEngine(e) + fm, ok = e.(FMAer) + if ok { + return fm.FMAScalar(ctx, a, x, y) + } } } + return Mul(a, x, WithIncr(y)) } @@ -432,13 +557,66 @@ func MatMul(a, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { err = errors.Errorf(dtypeMismatch, a.Dtype(), b.Dtype()) return } + ad, aok := a.(*Dense) + _, bok := b.(*Dense) + if aok && bok { + // fast path + return ad.MatMul(b, opts...) + } - switch at := a.(type) { - case *Dense: - bt := b.(*Dense) - return at.MatMul(bt, opts...) + // check that both are matrices + if !a.Shape().IsMatrix() || !b.Shape().IsMatrix() { + err = errors.Errorf("MatMul requires both operands to be matrices. Got t's shape: %v, other's shape: %v", a.Shape(), b.Shape()) + return } - panic("Unreachable") + + // checks that t is mxk matrix + var m, n, k int + m = a.Shape()[0] + k = a.Shape()[1] + n = b.Shape()[1] + + // check shape + if k != b.Shape()[0] { + err = errors.Errorf(shapeMismatch, a.Shape(), b.Shape()) + return + } + + // check whether retVal has the same size as the resulting matrix would be: mxn + expectedShape := Shape{m, n} + + eng := a.Engine() + mm, ok := eng.(MatMuler) + if !ok { + eng = b.Engine() + mm, ok = eng.(MatMuler) + } + if !ok { + return nil, errors.Errorf("Neither a or b have an engine that is a MatMuler. a: %T, b: %T", a.Engine(), b.Engine()) + } + + var reuse Tensor + fo := ParseFuncOpts(opts...) 
+ defer returnOpOpt(fo) + ctx := fo.Context() + reuse = fo.Reuse() + if reuse == nil { + return nil, errors.Errorf("MatMul requires passing in of a reuse Tensor for now.") + } + + if err := checkFixShape(reuse, expectedShape); err != nil { + return nil, errors.Wrapf(err, opFail, "MatMul") + } + if err = mm.MatMul(ctx, a, b, reuse); err != nil { + return nil, errors.Wrapf(err, opFail, "MatMul") + } + + incr := fo.Incr() + if incr != nil { + return Add(incr, reuse, UseUnsafe()) + } + return reuse, nil + } // MatVecMul performs matrix-vector multiplication between two Tensors. `a` is expected to be a matrix, and `b` is expected to be a vector @@ -457,7 +635,7 @@ func MatVecMul(a, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } // Inner finds the inner products of two vector Tensors. Both arguments to the functions are eexpected to be vectors. -func Inner(a, b Tensor) (retVal interface{}, err error) { +func Inner(a, b Tensor, opts ...FuncOpt) (retVal interface{}, err error) { if a.Dtype() != b.Dtype() { err = errors.Errorf(dtypeMismatch, a.Dtype(), b.Dtype()) return diff --git a/api_arith_generated_test.go b/api_arith_generated_test.go index 30345d6..f26b7f5 100644 --- a/api_arith_generated_test.go +++ b/api_arith_generated_test.go @@ -1,17 +1,21 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( + "context" "testing" "testing/quick" + "time" + + "gorgonia.org/dtype" ) +// Code generated by genlib2. DO NOT EDIT. + func TestAdd(t *testing.T) { iden := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -37,7 +41,7 @@ func TestSub(t *testing.T) { inv := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Suber) we = we || !ok @@ -64,7 +68,7 @@ func TestMul(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -91,7 +95,7 @@ func TestDiv(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -118,7 +122,7 @@ func TestPow(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := a.Engine().(Power) we = we || !ok @@ -144,7 +148,7 @@ func TestAdd_unsafe(t *testing.T) { iden := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -163,7 +167,6 @@ func TestAdd_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return 
false } - return true } if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -175,7 +178,7 @@ func TestSub_unsafe(t *testing.T) { inv := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Suber) we = we || !ok @@ -195,7 +198,6 @@ func TestSub_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -207,7 +209,7 @@ func TestMul_unsafe(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -226,7 +228,6 @@ func TestMul_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -239,7 +240,7 @@ func TestDiv_unsafe(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -259,7 +260,6 @@ func TestDiv_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -271,7 +271,7 @@ func TestPow_unsafe(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := a.Engine().(Power) we = we || !ok @@ -290,7 +290,6 @@ func TestPow_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -303,7 +302,7 @@ func TestAdd_reuse(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -335,7 +334,7 @@ func TestSub_reuse(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Suber) we = we || !ok @@ -368,7 +367,7 @@ func TestMul_reuse(t *testing.T) { b.Memset(identityVal(1, a.t)) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -401,7 +400,7 @@ func TestDiv_reuse(t *testing.T) { b.Memset(identityVal(1, a.t)) reuse := New(Of(a.t), 
WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -434,7 +433,7 @@ func TestPow_reuse(t *testing.T) { b.Memset(identityVal(1, a.t)) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := a.Engine().(Power) we = we || !ok @@ -468,7 +467,7 @@ func TestAdd_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -497,7 +496,7 @@ func TestSub_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Suber) we = we || !ok @@ -527,7 +526,7 @@ func TestMul_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -557,7 +556,7 @@ func TestDiv_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -587,7 +586,7 @@ func TestPow_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := a.Engine().(Power) we = we || !ok @@ -608,6 +607,204 @@ func TestPow_incr(t *testing.T) { t.Errorf("Identity test for Pow failed: %v", err) } +} +func TestAdd_context(t *testing.T) { + iden := func(a *Dense) bool { + b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := a.Engine().(Adder) + we = we || !ok + + ret, err := Add(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Add", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Identity test for Add failed: %v", err) + } + +} +func TestSub_context(t *testing.T) { + inv := func(a *Dense) bool { + b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = 
context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := a.Engine().(Suber) + we = we || !ok + + ret, err := Sub(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Sub", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + ret, err = Add(ret, b, UseUnsafe()) + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(inv, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Inv test for Sub failed: %v", err) + } +} +func TestMul_context(t *testing.T) { + iden := func(a *Dense) bool { + b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) + b.Memset(identityVal(1, a.t)) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := a.Engine().(Muler) + we = we || !ok + + ret, err := Mul(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Mul", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Identity test for Mul failed: %v", err) + } + +} +func TestDiv_context(t *testing.T) { + inv := func(a *Dense) bool { + b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) + b.Memset(identityVal(1, a.t)) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := a.Engine().(Diver) + we = we || !ok + + ret, err := Div(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Div", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + ret, err = Mul(ret, b, UseUnsafe()) + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(inv, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Inv test for Div failed: %v", err) + } +} +func TestPow_context(t *testing.T) { + iden := func(a *Dense) bool { + b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) + b.Memset(identityVal(1, a.t)) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } 
else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) + _, ok := a.Engine().(Power) + we = we || !ok + + ret, err := Pow(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Pow", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Identity test for Pow failed: %v", err) + } + } func TestAddScalar(t *testing.T) { iden1 := func(q *Dense) bool { @@ -615,7 +812,7 @@ func TestAddScalar(t *testing.T) { b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -641,7 +838,7 @@ func TestAddScalar(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -662,6 +859,32 @@ func TestAddScalar(t *testing.T) { t.Errorf("Identity test for Add (scalar as left, tensor as right) failed: %v", err) } + type Foo int + wt1 := func(a *Dense) bool { + b := Foo(0) + ret, err := Add(a, b) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Add (tensor as left, scalar as right) failed: %v", err) + } + + wt2 := func(a *Dense) bool { + b := Foo(0) + ret, err := Add(b, a) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Add (tensor as right, scalar as left) failed: %v", err) + } } func TestSubScalar(t *testing.T) { inv1 := func(q *Dense) bool { @@ -669,7 +892,7 @@ func TestSubScalar(t *testing.T) { b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -695,7 +918,7 @@ func TestSubScalar(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -716,6 +939,32 @@ func TestSubScalar(t *testing.T) { if err := quick.Check(inv2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { t.Errorf("Inv test for Sub (scalar as left, tensor as right) failed: %v", err) } + type Foo int + wt1 := func(a *Dense) bool { + b := Foo(0) + ret, err := Sub(a, b) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Sub (tensor as left, scalar as right) failed: %v", err) + } + + wt2 := func(a *Dense) bool { + b := Foo(0) + ret, err := Sub(b, a) + if err == nil { + return false + } + _ = ret + return true + } + if err := 
quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Sub (tensor as right, scalar as left) failed: %v", err) + } } func TestMulScalar(t *testing.T) { iden1 := func(q *Dense) bool { @@ -723,7 +972,7 @@ func TestMulScalar(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -749,7 +998,7 @@ func TestMulScalar(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -770,6 +1019,32 @@ func TestMulScalar(t *testing.T) { t.Errorf("Identity test for Mul (scalar as left, tensor as right) failed: %v", err) } + type Foo int + wt1 := func(a *Dense) bool { + b := Foo(0) + ret, err := Mul(a, b) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Mul (tensor as left, scalar as right) failed: %v", err) + } + + wt2 := func(a *Dense) bool { + b := Foo(0) + ret, err := Mul(b, a) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Mul (tensor as right, scalar as left) failed: %v", err) + } } func TestDivScalar(t *testing.T) { inv1 := func(q *Dense) bool { @@ -777,7 +1052,7 @@ func TestDivScalar(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Diver) we = we || !ok @@ -799,6 +1074,32 @@ func TestDivScalar(t *testing.T) { t.Errorf("Inv test for Div (tensor as left, scalar as right) failed: %v", err) } + type Foo int + wt1 := func(a *Dense) bool { + b := Foo(0) + ret, err := Div(a, b) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Div (tensor as left, scalar as right) failed: %v", err) + } + + wt2 := func(a *Dense) bool { + b := Foo(0) + ret, err := Div(b, a) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Div (tensor as right, scalar as left) failed: %v", err) + } } func TestPowScalar(t *testing.T) { iden1 := func(q *Dense) bool { @@ -806,7 +1107,7 @@ func TestPowScalar(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := q.Engine().(Power) we = we || !ok @@ -828,6 +1129,32 @@ func TestPowScalar(t *testing.T) { t.Errorf("Identity test for Pow (tensor as left, scalar as right) failed: %v", err) } + type Foo int + wt1 := func(a *Dense) bool { + b := Foo(0) + ret, err := Pow(a, b) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Pow (tensor as left, scalar as right) failed: %v", err) + } + + wt2 := 
func(a *Dense) bool { + b := Foo(0) + ret, err := Pow(b, a) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Pow (tensor as right, scalar as left) failed: %v", err) + } } func TestAddScalar_unsafe(t *testing.T) { iden1 := func(q *Dense) bool { @@ -835,7 +1162,7 @@ func TestAddScalar_unsafe(t *testing.T) { b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -854,7 +1181,6 @@ func TestAddScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -866,7 +1192,7 @@ func TestAddScalar_unsafe(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -885,7 +1211,6 @@ func TestAddScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -899,7 +1224,7 @@ func TestSubScalar_unsafe(t *testing.T) { b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -919,7 +1244,6 @@ func TestSubScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -930,7 +1254,7 @@ func TestSubScalar_unsafe(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -950,7 +1274,6 @@ func TestSubScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -963,7 +1286,7 @@ func TestMulScalar_unsafe(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -982,7 +1305,6 @@ func TestMulScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -994,7 +1316,7 @@ func TestMulScalar_unsafe(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1013,7 +1335,6 @@ func TestMulScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -1027,7 +1348,7 @@ func TestDivScalar_unsafe(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Diver) we = we || 
!ok @@ -1047,7 +1368,6 @@ func TestDivScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -1061,7 +1381,7 @@ func TestPowScalar_unsafe(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := q.Engine().(Power) we = we || !ok @@ -1080,7 +1400,6 @@ func TestPowScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -1096,7 +1415,7 @@ func TestAddScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1128,7 +1447,7 @@ func TestAddScalar_reuse(t *testing.T) { b := identityVal(0, q.t) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1162,7 +1481,7 @@ func TestSubScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -1194,7 +1513,7 @@ func TestSubScalar_reuse(t *testing.T) { b := identityVal(0, q.t) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -1228,7 +1547,7 @@ func TestMulScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1260,7 +1579,7 @@ func TestMulScalar_reuse(t *testing.T) { b := identityVal(1, q.t) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1294,7 +1613,7 @@ func TestDivScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Diver) we = we || !ok @@ -1329,7 +1648,7 @@ func TestPowScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := q.Engine().(Power) we = we || !ok @@ -1366,7 +1685,7 @@ func TestAddScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1395,7 +1714,7 @@ func TestAddScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) 
incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1426,7 +1745,7 @@ func TestSubScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -1458,7 +1777,7 @@ func TestMulScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1487,7 +1806,7 @@ func TestMulScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1518,7 +1837,7 @@ func TestDivScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Diver) we = we || !ok @@ -1550,7 +1869,7 @@ func TestPowScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := q.Engine().(Power) we = we || !ok @@ -1573,3 +1892,327 @@ func TestPowScalar_incr(t *testing.T) { } } +func TestAddScalar_context(t *testing.T) { + iden1 := func(q *Dense) bool { + a := q.Clone().(*Dense) + b := identityVal(0, q.t) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := q.Engine().(Adder) + we = we || !ok + + ret, err := Add(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Add", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + + if err := quick.Check(iden1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Identity test for Add (tensor as left, scalar as right) failed: %v", err) + } + + iden2 := func(q *Dense) bool { + a := q.Clone().(*Dense) + b := identityVal(0, q.t) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := q.Engine().(Adder) + we = we || !ok + + ret, err := Add(b, a, WithContext(ctx)) + if _, ok := 
err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Add", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(iden2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Identity test for Add (scalar as left, tensor as right) failed: %v", err) + } + +} +func TestSubScalar_context(t *testing.T) { + inv1 := func(q *Dense) bool { + a := q.Clone().(*Dense) + b := identityVal(0, q.t) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) + _, ok := q.Engine().(Suber) + we = we || !ok + + ret, err := Sub(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "SubVS", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + ret, err = Add(ret, b, UseUnsafe()) + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(inv1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Inv test for Sub (tensor as left, scalar as right) failed: %v", err) + } + + inv2 := func(q *Dense) bool { + a := q.Clone().(*Dense) + b := identityVal(0, q.t) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) + _, ok := q.Engine().(Suber) + we = we || !ok + + ret, err := Sub(b, a, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "SubSV", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + ret, err = Sub(b, ret, UseUnsafe()) + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(inv2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Inv test for Sub (scalar as left, tensor as right) failed: %v", err) + } +} +func TestMulScalar_context(t *testing.T) { + iden1 := func(q *Dense) bool { + a := q.Clone().(*Dense) + b := identityVal(1, q.t) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := q.Engine().(Muler) + we = we || !ok + + ret, err := Mul(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Mul", a, b, we, err); 
retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + + if err := quick.Check(iden1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Identity test for Mul (tensor as left, scalar as right) failed: %v", err) + } + + iden2 := func(q *Dense) bool { + a := q.Clone().(*Dense) + b := identityVal(1, q.t) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := q.Engine().(Muler) + we = we || !ok + + ret, err := Mul(b, a, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Mul", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(iden2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Identity test for Mul (scalar as left, tensor as right) failed: %v", err) + } + +} +func TestDivScalar_context(t *testing.T) { + inv1 := func(q *Dense) bool { + a := q.Clone().(*Dense) + b := identityVal(1, q.t) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := q.Engine().(Diver) + we = we || !ok + + ret, err := Div(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "DivVS", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + ret, err = Mul(ret, b, UseUnsafe()) + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(inv1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Inv test for Div (tensor as left, scalar as right) failed: %v", err) + } + +} +func TestPowScalar_context(t *testing.T) { + iden1 := func(q *Dense) bool { + a := q.Clone().(*Dense) + b := identityVal(1, q.t) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) + _, ok := q.Engine().(Power) + we = we || !ok + + ret, err := Pow(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Pow", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + 
return true + } + + if err := quick.Check(iden1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Identity test for Pow (tensor as left, scalar as right) failed: %v", err) + } + +} diff --git a/api_arith_test.go b/api_arith_test.go index d7bd9a5..3a3cf67 100644 --- a/api_arith_test.go +++ b/api_arith_test.go @@ -1,154 +1,575 @@ -package tensor - -import ( - "log" - "math/rand" - "testing" - "testing/quick" - "time" - - "github.com/stretchr/testify/assert" -) - -// This file contains the tests for API functions that aren't generated by genlib - -func TestMod(t *testing.T) { - a := New(WithBacking([]float64{1, 2, 3, 4})) - b := New(WithBacking([]float64{1, 1, 1, 1})) - var correct interface{} = []float64{0, 0, 0, 0} - - // vec-vec - res, err := Mod(a, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // scalar - if res, err = Mod(a, 1.0); err != nil { - t.Fatal("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) -} - -func TestFMA(t *testing.T) { - same := func(q *Dense) bool { - a := q.Clone().(*Dense) - x := q.Clone().(*Dense) - y := New(Of(q.Dtype()), WithShape(q.Shape().Clone()...)) - y.Memset(identityVal(100, q.Dtype())) - WithEngine(q.Engine())(y) - y2 := y.Clone().(*Dense) - - we, willFailEq := willerr(a, numberTypes, nil) - // _, ok1 := q.Engine().(FMAer) - // _, ok2 := q.Engine().(Muler) - // _, ok3 := q.Engine().(Adder) - // we = we || (!ok1 && (!ok2 || !ok3)) - - f, err := FMA(a, x, y) - if err, retEarly := qcErrCheck(t, "FMA#1", a, x, we, err); retEarly { - if err != nil { - log.Printf("q.Engine() %T", q.Engine()) - return false - } - return true - } - - we, _ = willerr(a, numberTypes, nil) - _, ok := a.Engine().(Muler) - we = we || !ok - wi, err := Mul(a, x, WithIncr(y2)) - if err, retEarly := qcErrCheck(t, "FMA#2", a, x, we, err); retEarly { - if err != nil { - return false - } - return true - } - return qcEqCheck(t, q.Dtype(), willFailEq, wi, f) - } - r := rand.New(rand.NewSource(time.Now().UnixNano())) - if err := quick.Check(same, &quick.Config{Rand: r}); err != nil { - t.Error(err) - } - - // specific engines - var eng Engine - - // FLOAT64 ENGINE - - // vec-vec - eng = Float64Engine{} - a := New(WithBacking(Range(Float64, 0, 100)), WithEngine(eng)) - x := New(WithBacking(Range(Float64, 1, 101)), WithEngine(eng)) - y := New(Of(Float64), WithShape(100), WithEngine(eng)) - - f, err := FMA(a, x, y) - if err != nil { - t.Fatal(err) - } - - a2 := New(WithBacking(Range(Float64, 0, 100))) - x2 := New(WithBacking(Range(Float64, 1, 101))) - y2 := New(Of(Float64), WithShape(100)) - f2, err := Mul(a2, x2, WithIncr(y2)) - if err != nil { - t.Fatal(err) - } - - assert.Equal(t, f.Data(), f2.Data()) - - // vec-scalar - a = New(WithBacking(Range(Float64, 0, 100)), WithEngine(eng)) - y = New(Of(Float64), WithShape(100)) - - if f, err = FMA(a, 2.0, y); err != nil { - t.Fatal(err) - } - - a2 = New(WithBacking(Range(Float64, 0, 100))) - y2 = New(Of(Float64), WithShape(100)) - if f2, err = Mul(a2, 2.0, WithIncr(y2)); err != nil { - t.Fatal(err) - } - - assert.Equal(t, f.Data(), f2.Data()) - - // FLOAT32 engine - eng = Float32Engine{} - a = New(WithBacking(Range(Float32, 0, 100)), WithEngine(eng)) - x = New(WithBacking(Range(Float32, 1, 101)), WithEngine(eng)) - y = New(Of(Float32), WithShape(100), WithEngine(eng)) - - f, err = FMA(a, x, y) - if err != nil { - t.Fatal(err) - } - - a2 = New(WithBacking(Range(Float32, 0, 100))) - x2 = New(WithBacking(Range(Float32, 1, 101))) - y2 = New(Of(Float32), 
WithShape(100)) - f2, err = Mul(a2, x2, WithIncr(y2)) - if err != nil { - t.Fatal(err) - } - - assert.Equal(t, f.Data(), f2.Data()) - - // vec-scalar - a = New(WithBacking(Range(Float32, 0, 100)), WithEngine(eng)) - y = New(Of(Float32), WithShape(100)) - - if f, err = FMA(a, float32(2), y); err != nil { - t.Fatal(err) - } - - a2 = New(WithBacking(Range(Float32, 0, 100))) - y2 = New(Of(Float32), WithShape(100)) - if f2, err = Mul(a2, float32(2), WithIncr(y2)); err != nil { - t.Fatal(err) - } - - assert.Equal(t, f.Data(), f2.Data()) - -} +package tensor + +import ( + "log" + "math/rand" + "testing" + "testing/quick" + "time" + + "github.com/stretchr/testify/assert" + "gorgonia.org/dtype" +) + +// This file contains the tests for API functions that aren't generated by genlib + +func TestMod(t *testing.T) { + a := New(WithBacking([]float64{1, 2, 3, 4})) + b := New(WithBacking([]float64{1, 1, 1, 1})) + var correct interface{} = []float64{0, 0, 0, 0} + + // vec-vec + res, err := Mod(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar + if res, err = Mod(a, 1.0); err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} + +func TestFMA(t *testing.T) { + same := func(q *Dense) bool { + a := q.Clone().(*Dense) + x := q.Clone().(*Dense) + y := New(Of(q.Dtype()), WithShape(q.Shape().Clone()...)) + y.Memset(identityVal(100, q.Dtype())) + WithEngine(q.Engine())(y) + y2 := y.Clone().(*Dense) + + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok1 := q.Engine().(FMAer) + _, ok2 := q.Engine().(Muler) + _, ok3 := q.Engine().(Adder) + we = we || (!ok1 && (!ok2 || !ok3)) + + f, err := FMA(a, x, y) + if err, retEarly := qcErrCheck(t, "FMA#1", a, x, we, err); retEarly { + if err != nil { + log.Printf("q.Engine() %T", q.Engine()) + return false + } + return true + } + + we, _ = willerr(a, dtype.Number, nilTC) + _, ok := a.Engine().(Muler) + we = we || !ok + wi, err := Mul(a, x, WithIncr(y2)) + if err, retEarly := qcErrCheck(t, "FMA#2", a, x, we, err); retEarly { + if err != nil { + return false + } + return true + } + return qcEqCheck(t, q.Dtype(), willFailEq, wi, f) + } + r := rand.New(rand.NewSource(time.Now().UnixNano())) + if err := quick.Check(same, &quick.Config{Rand: r}); err != nil { + t.Error(err) + } + + // specific engines + var eng Engine + + // FLOAT64 ENGINE + + // vec-vec + eng = Float64Engine{} + a := New(WithBacking(Range(Float64, 0, 100)), WithEngine(eng)) + x := New(WithBacking(Range(Float64, 1, 101)), WithEngine(eng)) + y := New(Of(Float64), WithShape(100), WithEngine(eng)) + + f, err := FMA(a, x, y) + if err != nil { + t.Fatal(err) + } + + a2 := New(WithBacking(Range(Float64, 0, 100))) + x2 := New(WithBacking(Range(Float64, 1, 101))) + y2 := New(Of(Float64), WithShape(100)) + f2, err := Mul(a2, x2, WithIncr(y2)) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, f.Data(), f2.Data()) + + // vec-scalar + a = New(WithBacking(Range(Float64, 0, 100)), WithEngine(eng)) + y = New(Of(Float64), WithShape(100)) + + if f, err = FMA(a, 2.0, y); err != nil { + t.Fatal(err) + } + + a2 = New(WithBacking(Range(Float64, 0, 100))) + y2 = New(Of(Float64), WithShape(100)) + if f2, err = Mul(a2, 2.0, WithIncr(y2)); err != nil { + t.Fatal(err) + } + + assert.Equal(t, f.Data(), f2.Data()) + + // FLOAT32 engine + eng = Float32Engine{} + a = New(WithBacking(Range(Float32, 0, 100)), WithEngine(eng)) + x = New(WithBacking(Range(Float32, 1, 101)), WithEngine(eng)) + y = New(Of(Float32), WithShape(100), WithEngine(eng)) 
+ + f, err = FMA(a, x, y) + if err != nil { + t.Fatal(err) + } + + a2 = New(WithBacking(Range(Float32, 0, 100))) + x2 = New(WithBacking(Range(Float32, 1, 101))) + y2 = New(Of(Float32), WithShape(100)) + f2, err = Mul(a2, x2, WithIncr(y2)) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, f.Data(), f2.Data()) + + // vec-scalar + a = New(WithBacking(Range(Float32, 0, 100)), WithEngine(eng)) + y = New(Of(Float32), WithShape(100)) + + if f, err = FMA(a, float32(2), y); err != nil { + t.Fatal(err) + } + + a2 = New(WithBacking(Range(Float32, 0, 100))) + y2 = New(Of(Float32), WithShape(100)) + if f2, err = Mul(a2, float32(2), WithIncr(y2)); err != nil { + t.Fatal(err) + } + + assert.Equal(t, f.Data(), f2.Data()) + +} + +func TestMulScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{2})) + b := New(WithBacking([]float64{3})) + var correct interface{} = 6.0 + + res, err := Mul(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // Test commutativity + res, err = Mul(b, a) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{3, 2})) + b = New(WithBacking([]float64{2})) + correct = []float64{6, 4} + + res, err = Mul(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // Test commutativity + res, err = Mul(b, a) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{3, 5})) + b = New(WithBacking([]float64{7, 2})) + correct = []float64{21, 10} + + res, err = Mul(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // Test commutativity + res, err = Mul(b, a) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // Interface - tensor + ai := 2.0 + b = NewDense(Float64, Shape{1, 1}, WithBacking([]float64{3})) + correct = []float64{6.0} + + res, err = Mul(ai, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // Commutativity + res, err = Mul(b, ai) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} + +func TestDivScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{6})) + b := New(WithBacking([]float64{2})) + var correct interface{} = 3.0 + + res, err := Div(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{6, 4})) + b = New(WithBacking([]float64{2})) + correct = []float64{3, 2} + + res, err = Div(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor-scalar + a = New(WithBacking([]float64{6})) + b = New(WithBacking([]float64{3, 2})) + correct = []float64{2, 3} + + res, err = Div(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{21, 10})) + b = New(WithBacking([]float64{7, 2})) + correct = []float64{3, 5} + + res, err = Div(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // interface-scalar + ai := 6.0 + b = New(WithBacking([]float64{2})) + correct = 3.0 + + res, err = Div(ai, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-interface + a = 
New(WithBacking([]float64{6})) + bi := 2.0 + correct = 3.0 + + res, err = Div(a, bi) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} + +func TestAddScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{2})) + b := New(WithBacking([]float64{3})) + var correct interface{} = 5.0 + + res, err := Add(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // Test commutativity + res, err = Add(b, a) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{3, 2})) + b = New(WithBacking([]float64{2})) + correct = []float64{5, 4} + + res, err = Add(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // Test commutativity + res, err = Add(b, a) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{3, 5})) + b = New(WithBacking([]float64{7, 2})) + correct = []float64{10, 7} + + res, err = Add(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // Test commutativity + res, err = Add(b, a) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // interface-scalar + ai := 2.0 + b = New(WithBacking([]float64{3})) + correct = 5.0 + + res, err = Add(ai, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // Test commutativity + res, err = Add(b, ai) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} + +func TestSubScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{6})) + b := New(WithBacking([]float64{2})) + var correct interface{} = 4.0 + + res, err := Sub(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{6, 4})) + b = New(WithBacking([]float64{2})) + correct = []float64{4, 2} + + res, err = Sub(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor-scalar + a = New(WithBacking([]float64{6})) + b = New(WithBacking([]float64{3, 2})) + correct = []float64{3, 4} + + res, err = Sub(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{21, 10})) + b = New(WithBacking([]float64{7, 2})) + correct = []float64{14, 8} + + res, err = Sub(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // interface-scalar + ai := 6.0 + b = New(WithBacking([]float64{2})) + correct = 4.0 + + res, err = Sub(ai, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-interface + a = New(WithBacking([]float64{6})) + bi := 2.0 + correct = 4.0 + + res, err = Sub(a, bi) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} + +func TestModScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{5})) + b := New(WithBacking([]float64{2})) + var correct interface{} = 1.0 + + res, err := Mod(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{5, 4})) + b = New(WithBacking([]float64{2})) + correct = 
[]float64{1, 0} + + res, err = Mod(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor-scalar + a = New(WithBacking([]float64{5})) + b = New(WithBacking([]float64{3, 2})) + correct = []float64{2, 1} + + res, err = Mod(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{22, 10})) + b = New(WithBacking([]float64{7, 2})) + correct = []float64{1, 0} + + res, err = Mod(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // interface-scalar + ai := 5.0 + b = New(WithBacking([]float64{2})) + correct = 1.0 + + res, err = Mod(ai, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-interface + a = New(WithBacking([]float64{5})) + bi := 2.0 + correct = 1.0 + + res, err = Mod(a, bi) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} + +func TestPowScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{6})) + b := New(WithBacking([]float64{2})) + var correct interface{} = 36.0 + + res, err := Pow(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{6, 4})) + b = New(WithBacking([]float64{2})) + correct = []float64{36, 16} + + res, err = Pow(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor-scalar + a = New(WithBacking([]float64{6})) + b = New(WithBacking([]float64{3, 2})) + correct = []float64{216, 36} + + res, err = Pow(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{3, 10})) + b = New(WithBacking([]float64{7, 2})) + correct = []float64{2187, 100} + + res, err = Pow(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // interface-scalar + ai := 6.0 + b = New(WithBacking([]float64{2})) + correct = 36.0 + + res, err = Pow(ai, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-interface + a = New(WithBacking([]float64{6})) + bi := 2.0 + correct = 36.0 + + res, err = Pow(a, bi) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} diff --git a/api_cmp.go b/api_cmp.go index e79398e..b2ac050 100644 --- a/api_cmp.go +++ b/api_cmp.go @@ -1,6 +1,8 @@ package tensor -import "github.com/pkg/errors" +import ( + "github.com/pkg/errors" +) // public API for comparison ops @@ -17,12 +19,30 @@ func Lt(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { lter, ok = at.Engine().(Lter) switch bt := b.(type) { case Tensor: - if !ok { - if lter, ok = bt.Engine().(Lter); !ok { - return nil, errors.Errorf("Neither operands have engines that support Lt") + if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor comparison + if !ok { + if lter, ok = bt.Engine().(Lter); !ok { + return nil, errors.Errorf("Neither operands have engines that support Lt") + } + } + + return lter.Lt(at, bt, opts...) 
+ } else { + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false // a Scalar-Tensor * b Tensor + tmp := at + at = bt + bt = tmp + } else { + leftTensor = true // a Tensor * b Scalar-Tensor } + + if !ok { + return nil, errors.Errorf("Engine does not support Lt") + } + return lter.LtScalar(at, bt, leftTensor, opts...) } - return lter.Lt(at, bt, opts...) default: if !ok { return nil, errors.Errorf("Engine does not support Lt") @@ -55,12 +75,29 @@ func Gt(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { gter, ok = at.Engine().(Gter) switch bt := b.(type) { case Tensor: - if !ok { - if gter, ok = bt.Engine().(Gter); !ok { - return nil, errors.Errorf("Neither operands have engines that support Gt") + if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor comparison + if !ok { + if gter, ok = bt.Engine().(Gter); !ok { + return nil, errors.Errorf("Neither operands have engines that support Gt") + } + } + return gter.Gt(at, bt, opts...) + } else { + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false // a Scalar-Tensor * b Tensor + tmp := at + at = bt + bt = tmp + } else { + leftTensor = true // a Tensor * b Scalar-Tensor } + + if !ok { + return nil, errors.Errorf("Engine does not support Gt") + } + return gter.GtScalar(at, bt, leftTensor, opts...) } - return gter.Gt(at, bt, opts...) default: if !ok { return nil, errors.Errorf("Engine does not support Gt") @@ -93,12 +130,30 @@ func Lte(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { lteer, ok = at.Engine().(Lteer) switch bt := b.(type) { case Tensor: - if !ok { - if lteer, ok = bt.Engine().(Lteer); !ok { - return nil, errors.Errorf("Neither operands have engines that support Lte") + if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor comparison + if !ok { + if lteer, ok = bt.Engine().(Lteer); !ok { + return nil, errors.Errorf("Neither operands have engines that support Lte") + } + } + return lteer.Lte(at, bt, opts...) + } else { + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false // a Scalar-Tensor * b Tensor + tmp := at + at = bt + bt = tmp + } else { + leftTensor = true // a Tensor * b Scalar-Tensor + } + + if !ok { + return nil, errors.Errorf("Engine does not support Lte") } + return lteer.LteScalar(at, bt, leftTensor, opts...) } - return lteer.Lte(at, bt, opts...) + default: if !ok { return nil, errors.Errorf("Engine does not support Lte") @@ -131,12 +186,29 @@ func Gte(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { gteer, ok = at.Engine().(Gteer) switch bt := b.(type) { case Tensor: - if !ok { - if gteer, ok = bt.Engine().(Gteer); !ok { - return nil, errors.Errorf("Neither operands have engines that support Gte") + if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor comparison + if !ok { + if gteer, ok = bt.Engine().(Gteer); !ok { + return nil, errors.Errorf("Neither operands have engines that support Gte") + } } + return gteer.Gte(at, bt, opts...) + } else { + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false // a Scalar-Tensor * b Tensor + tmp := at + at = bt + bt = tmp + } else { + leftTensor = true // a Tensor * b Scalar-Tensor + } + + if !ok { + return nil, errors.Errorf("Engine does not support Gte") + } + return gteer.GteScalar(at, bt, leftTensor, opts...) } - return gteer.Gte(at, bt, opts...) 
default: if !ok { return nil, errors.Errorf("Engine does not support Gte") @@ -169,12 +241,30 @@ func ElEq(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { eleqer, ok = at.Engine().(ElEqer) switch bt := b.(type) { case Tensor: - if !ok { - if eleqer, ok = bt.Engine().(ElEqer); !ok { - return nil, errors.Errorf("Neither operands have engines that support ElEq") + if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor comparison + if !ok { + if eleqer, ok = bt.Engine().(ElEqer); !ok { + return nil, errors.Errorf("Neither operands have engines that support ElEq") + } + } + return eleqer.ElEq(at, bt, opts...) + } else { + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false // a Scalar-Tensor * b Tensor + tmp := at + at = bt + bt = tmp + } else { + leftTensor = true // a Tensor * b Scalar-Tensor } + + if !ok { + return nil, errors.Errorf("Engine does not support ElEq") + } + return eleqer.EqScalar(at, bt, leftTensor, opts...) } - return eleqer.ElEq(at, bt, opts...) + default: if !ok { return nil, errors.Errorf("Engine does not support ElEq") @@ -207,12 +297,26 @@ func ElNe(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { eleqer, ok = at.Engine().(ElEqer) switch bt := b.(type) { case Tensor: - if !ok { - if eleqer, ok = bt.Engine().(ElEqer); !ok { - return nil, errors.Errorf("Neither operands have engines that support ElEq") + if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor comparison + if !ok { + if eleqer, ok = bt.Engine().(ElEqer); !ok { + return nil, errors.Errorf("Neither operands have engines that support ElEq") + } + } + return eleqer.ElNe(at, bt, opts...) + } else { + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false + at, bt = bt, at + } else { + leftTensor = true + } + if !ok { + return nil, errors.Errorf("Engine does not support ElNE") } + return eleqer.NeScalar(at, bt, leftTensor, opts...) } - return eleqer.ElNe(at, bt, opts...) default: if !ok { return nil, errors.Errorf("Engine does not support ElEq") diff --git a/api_cmp_generated_test.go b/api_cmp_generated_test.go index 163ae5c..4a612d8 100644 --- a/api_cmp_generated_test.go +++ b/api_cmp_generated_test.go @@ -1,16 +1,18 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( "reflect" "testing" "testing/quick" + + "gorgonia.org/dtype" ) +// Code generated by genlib2. DO NOT EDIT. 
+ func TestGt(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Gter) we = we || !ok @@ -62,13 +64,13 @@ func TestGt(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gt failed: %v", err) + t.Errorf("Transitivity test for Gt failed: %v", err) } } func TestGte(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Gteer) we = we || !ok @@ -120,13 +122,13 @@ func TestGte(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gte failed: %v", err) + t.Errorf("Transitivity test for Gte failed: %v", err) } } func TestLt(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Lter) we = we || !ok @@ -178,13 +180,13 @@ func TestLt(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lt failed: %v", err) + t.Errorf("Transitivity test for Lt failed: %v", err) } } func TestLte(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Lteer) we = we || !ok @@ -236,13 +238,13 @@ func TestLte(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lte failed: %v", err) + t.Errorf("Transitivity test for Lte failed: %v", err) } } func TestEq(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -294,11 +296,11 @@ func TestEq(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } symFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -328,12 +330,12 @@ func TestEq(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } } func TestNe(t *testing.T) { symFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -363,16 +365,16 @@ func TestNe(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElNe failed: %v", err) + t.Errorf("Transitivity test for ElNe failed: %v", err) } } func TestGt_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Gter) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit 
early if the generated type is not something we can handle } r := newRand() @@ -422,17 +424,17 @@ func TestGt_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gt failed: %v", err) + t.Errorf("Transitivity test for Gt failed: %v", err) } } func TestGte_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Gteer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -482,17 +484,17 @@ func TestGte_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gte failed: %v", err) + t.Errorf("Transitivity test for Gte failed: %v", err) } } func TestLt_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Lter) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -542,17 +544,17 @@ func TestLt_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lt failed: %v", err) + t.Errorf("Transitivity test for Lt failed: %v", err) } } func TestLte_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Lteer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -602,17 +604,17 @@ func TestLte_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lte failed: %v", err) + t.Errorf("Transitivity test for Lte failed: %v", err) } } func TestEq_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -662,15 +664,15 @@ func TestEq_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } symFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = 
we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -699,16 +701,16 @@ func TestEq_assame(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } } func TestNe_assame(t *testing.T) { symFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -737,12 +739,12 @@ func TestNe_assame(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElNe failed: %v", err) + t.Errorf("Transitivity test for ElNe failed: %v", err) } } func TestGtScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Gter) we = we || !ok @@ -792,13 +794,13 @@ func TestGtScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gt failed: %v", err) + t.Errorf("Transitivity test for Gt failed: %v", err) } } func TestGteScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Gteer) we = we || !ok @@ -848,13 +850,13 @@ func TestGteScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gte failed: %v", err) + t.Errorf("Transitivity test for Gte failed: %v", err) } } func TestLtScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Lter) we = we || !ok @@ -904,13 +906,13 @@ func TestLtScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lt failed: %v", err) + t.Errorf("Transitivity test for Lt failed: %v", err) } } func TestLteScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Lteer) we = we || !ok @@ -960,13 +962,13 @@ func TestLteScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lte failed: %v", err) + t.Errorf("Transitivity test for Lte failed: %v", err) } } func TestEqScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -1016,11 +1018,11 @@ func TestEqScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity 
test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } symFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -1048,12 +1050,12 @@ func TestEqScalar(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Symmetry test for ElEq failed: %v", err) + t.Errorf("Symmetry test for ElEq failed: %v", err) } } func TestNeScalar(t *testing.T) { symFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -1081,16 +1083,16 @@ func TestNeScalar(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Symmetry test for ElNe failed: %v", err) + t.Errorf("Symmetry test for ElNe failed: %v", err) } } func TestGtScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Gter) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1138,17 +1140,17 @@ func TestGtScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gt failed: %v", err) + t.Errorf("Transitivity test for Gt failed: %v", err) } } func TestGteScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Gteer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1196,17 +1198,17 @@ func TestGteScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gte failed: %v", err) + t.Errorf("Transitivity test for Gte failed: %v", err) } } func TestLtScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Lter) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1254,17 +1256,17 @@ func TestLtScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lt failed: %v", err) + t.Errorf("Transitivity test for Lt failed: %v", err) } } func TestLteScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Lteer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != 
nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1312,17 +1314,17 @@ func TestLteScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lte failed: %v", err) + t.Errorf("Transitivity test for Lte failed: %v", err) } } func TestEqScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1370,15 +1372,15 @@ func TestEqScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } symFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1405,16 +1407,16 @@ func TestEqScalar_assame(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Symmetry test for ElEq failed: %v", err) + t.Errorf("Symmetry test for ElEq failed: %v", err) } } func TestNeScalar_assame(t *testing.T) { symFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1441,6 +1443,6 @@ func TestNeScalar_assame(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Symmetry test for ElNe failed: %v", err) + t.Errorf("Symmetry test for ElNe failed: %v", err) } } diff --git a/api_cmp_test.go b/api_cmp_test.go new file mode 100644 index 0000000..9e785d7 --- /dev/null +++ b/api_cmp_test.go @@ -0,0 +1,239 @@ +package tensor + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// This file contains the tests for API functions that aren't generated by genlib + +func TestLtScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{6})) + b := New(WithBacking([]float64{2})) + var correct interface{} = false + + res, err := Lt(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{1, 4})) + b = New(WithBacking([]float64{2})) + correct = []bool{true, false} + + res, err = Lt(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + 
+ // tensor-scalar + a = New(WithBacking([]float64{3})) + b = New(WithBacking([]float64{6, 2})) + correct = []bool{true, false} + + res, err = Lt(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{21, 2})) + b = New(WithBacking([]float64{7, 10})) + correct = []bool{false, true} + + res, err = Lt(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} + +func TestGtScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{6})) + b := New(WithBacking([]float64{2})) + var correct interface{} = true + + res, err := Gt(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{1, 4})) + b = New(WithBacking([]float64{2})) + correct = []bool{false, true} + + res, err = Gt(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor-scalar + a = New(WithBacking([]float64{3})) + b = New(WithBacking([]float64{6, 2})) + correct = []bool{false, true} + + res, err = Gt(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{21, 2})) + b = New(WithBacking([]float64{7, 10})) + correct = []bool{true, false} + + res, err = Gt(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} + +func TestLteScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{6})) + b := New(WithBacking([]float64{2})) + var correct interface{} = false + + res, err := Lte(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{1, 2, 4})) + b = New(WithBacking([]float64{2})) + correct = []bool{true, true, false} + + res, err = Lte(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor-scalar + a = New(WithBacking([]float64{3})) + b = New(WithBacking([]float64{6, 2})) + correct = []bool{true, false} + + res, err = Lte(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{21, 2})) + b = New(WithBacking([]float64{7, 10})) + correct = []bool{false, true} + + res, err = Lte(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} + +func TestGteScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{6})) + b := New(WithBacking([]float64{2})) + var correct interface{} = true + + res, err := Gte(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{1, 2, 4})) + b = New(WithBacking([]float64{2})) + correct = []bool{false, true, true} + + res, err = Gte(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor-scalar + a = New(WithBacking([]float64{3})) + b = New(WithBacking([]float64{6, 3, 2})) + correct = []bool{false, true, true} + + res, err = Gte(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{21, 31, 2})) + b = New(WithBacking([]float64{7, 31, 10})) + correct = []bool{true, true, false} + + res, err = Gte(a, b) + if err != nil 
{ + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} + +func TestElEqScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{6})) + b := New(WithBacking([]float64{2})) + var correct interface{} = false + + res, err := ElEq(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{1, 2, 4})) + b = New(WithBacking([]float64{2})) + correct = []bool{false, true, false} + + res, err = ElEq(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor-scalar + a = New(WithBacking([]float64{3})) + b = New(WithBacking([]float64{6, 3, 2})) + correct = []bool{false, true, false} + + res, err = ElEq(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{21, 10})) + b = New(WithBacking([]float64{7, 10})) + correct = []bool{false, true} + + res, err = ElEq(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} diff --git a/api_matop.go b/api_matop.go index 0db34e7..4d98479 100644 --- a/api_matop.go +++ b/api_matop.go @@ -1,14 +1,38 @@ package tensor -import "github.com/pkg/errors" +import ( + "github.com/pkg/errors" +) // this file handles matops. While by default most of these matops should already have been defined as part of the // Tensor interface, not all are possible(for example, concatenating a sparse tensor), hence the need for the following functions +// Narrow narrows the tensor. +func Narrow(t Tensor, dim, start, length int) (View, error) { + dim = resolveAxis(dim, t.Dims()) + + slices := make([]Slice, MinInt(dim+1, t.Dims())) + slices[dim] = S(start, start+length, 1) + + return t.Slice(slices...) +} + // Repeat repeats a Tensor along the axis and given the number of repeats. func Repeat(t Tensor, axis int, repeats ...int) (retVal Tensor, err error) { - if r, ok := t.Engine().(Repeater); ok { - return r.Repeat(t, axis, repeats...) + e := t.Engine() + ctx := ctxFromEngine(e) + if r, ok := e.(Repeater); ok { + return r.Repeat(ctx, t, axis, repeats...) + } + return nil, errors.New("Engine does not support Repeat") +} + +// RepeatReuse repeats a Tensor along the axis and the given number of repeats, and puts the results in the provided reuse tensor. If the reuse tensor is not correctly sized, then an error will be given, but the results will still be valid. +func RepeatReuse(t, reuse Tensor, axis int, repeats ...int) (retval Tensor, err error) { + e := t.Engine() + ctx := ctxFromEngine(e) + if r, ok := e.(Repeater); ok { + return r.RepeatReuse(ctx, t, reuse, axis, repeats...) } return nil, errors.New("Engine does not support Repeat") } @@ -18,6 +42,14 @@ func T(t Tensor, axes ...int) (retVal Tensor, err error) { switch tt := t.(type) { case *Dense: return tt.SafeT(axes...) 
+ case DenseView: + var ret *Dense + if ret, err = tt.SafeT(axes...); err != nil { + return nil, errors.Wrap(err, "T() off a DenseView") + } + return DenseView{ret}, nil + default: + return nil, nyierr(typeNYI, t) } panic("Unreachable") } @@ -28,11 +60,20 @@ func Transpose(t Tensor, axes ...int) (retVal Tensor, err error) { case *Dense: var ret *Dense if ret, err = tt.SafeT(axes...); err != nil { - return + return nil, errors.Wrap(err, "Unable to perform .SafeT() on a *Dense") } ret.Transpose() retVal = ret return + case DenseView: + var ret *Dense + if ret, err = tt.SafeT(axes...); err != nil { + return nil, errors.Wrap(err, "Unable to perform .SafeT() on a DenseView") + } + ret.Transpose() + return DenseView{ret}, nil + default: + return nil, nyierr(typeNYI, t) } panic("Unreachable") } @@ -45,15 +86,30 @@ func Concat(axis int, t Tensor, others ...Tensor) (retVal Tensor, err error) { } switch T := t.(type) { case *Dense: + // IF YOU UPDATE THIS, UPDATE THE DENSE VIEW CASE TOO. + ts := make([]*Dense, len(others)) + for i, o := range others { + ot, err := assertDense(o) + if err == nil { + ts[i] = ot + continue + } + return nil, errors.Wrapf(err, "Expected all Tensors to be *Dense. Got %T instead", o) + } + return T.Concat(axis, ts...) + case DenseView: ts := make([]*Dense, len(others)) for i, o := range others { - if ot, ok := o.(*Dense); ok { + ot, err := assertDense(o) + if err == nil { + ts[i] = ot + continue + } - return nil, errors.Errorf("Expected all Tensors to be *Dense") + return nil, errors.Wrapf(err, "Expected all Tensors to be *Dense. Got %T instead", o) } return T.Concat(axis, ts...) + default: + return nil, nyierr(typeNYI, t) } panic("Unreachable") } @@ -76,7 +132,7 @@ func Copy(dst, src Tensor) error { copyDense(dt, st) return nil default: - return errors.Errorf("NYI for Copy %T", src) + return nyierr(typeNYI, src) } panic("Unreachable") } @@ -108,3 +164,78 @@ func Materialize(t Tensor) Tensor { return t } } + +func Diag(t Tensor) (retVal Tensor, err error) { + e := t.Engine() + ctx := ctxFromEngine(e) + if d, ok := e.(Diager); ok { + return d.Diag(ctx, t) + } + return nil, errors.Errorf("Unable to perform diagonalization of tensor. Engine %T does not support Diag", e) +} + +// ByIndices allows for selection of values of `a` by the indices listed in the `indices` tensor. +// The `indices` tensor has to be a vector-like tensor of ints. +func ByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + if axis >= a.Shape().Dims() { + return nil, errors.Errorf("Cannot select by indices on axis %d. Input only has %d dims", axis, a.Shape().Dims()) + } + if sbi, ok := a.Engine().(ByIndiceser); ok { + return sbi.SelectByIndices(a, indices, axis, opts...) + } + return nil, errors.Errorf("Unable to select by indices. Engine %T does not support that.", a.Engine()) +} + +// ByIndicesB is the backpropagation of ByIndices. +func ByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + if axis >= a.Shape().Dims() { + return nil, errors.Errorf("Cannot select by indices on axis %d. Input only has %d dims", axis, a.Shape().Dims()) + } + if sbi, ok := a.Engine().(ByIndiceser); ok { + return sbi.SelectByIndicesB(a, b, indices, axis, opts...) + } + return nil, errors.Errorf("Unable to select by indices. Engine %T does not support that.", a.Engine()) + } + +// LogSoftMax applies log softmax to the given tensor.
+func LogSoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + if sm, ok := x.Engine().(SoftMaxer); ok { + return sm.LogSoftMax(x, axis, opts...) + } + + return nil, errors.Errorf("Unable to apply LogSoftMax. Engine %T does not support that.", x.Engine()) +} + +// SoftMax applies softmax to the given tensor. +func SoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + if sm, ok := x.Engine().(SoftMaxer); ok { + return sm.SoftMax(x, axis, opts...) + } + + return nil, errors.Errorf("Unable to apply SoftMax. Engine %T does not support that.", x.Engine()) +} + +// SoftMaxB applies the softmax backwards operation. +func SoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + if sm, ok := output.Engine().(SoftMaxer); ok { + return sm.SoftMaxB(output, grad, axis, opts...) + } + + return nil, errors.Errorf("Unable to apply SoftMaxB. Engine %T does not support that.", output.Engine()) +} + +// LogSoftMaxB applies the log softmax backwards operation. +func LogSoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + if sm, ok := output.Engine().(SoftMaxer); ok { + return sm.LogSoftMaxB(output, grad, axis, opts...) + } + + return nil, errors.Errorf("Unable to apply LogSoftMaxB. Engine %T does not support that.", output.Engine()) +} + +func Scatter(a, indices Tensor, opts ...FuncOpt) (retVal Tensor, err error) { + if sc, ok := a.Engine().(Scatterer); ok { + return sc.Scatter(a, indices, opts...) + } + return nil, errors.Errorf("Unable to scatter. Engine %T does not support Scattering.", a.Engine()) +} diff --git a/api_minmax.go b/api_minmax.go new file mode 100644 index 0000000..e8a7de1 --- /dev/null +++ b/api_minmax.go @@ -0,0 +1,155 @@ +package tensor + +import "github.com/pkg/errors" + +func MinBetween(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { + var minbetweener MinBetweener + var oe StandardEngine + var ok bool + switch at := a.(type) { + case Tensor: + oe, _ = at.Engine().(StandardEngine) + switch bt := b.(type) { + case Tensor: + if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor addition + if oe != nil { + return oe.MinBetween(at, bt, opts...) + } + if oe, ok = bt.Engine().(StandardEngine); ok { + return oe.MinBetween(at, bt, opts...) + } + if minbetweener, ok = at.Engine().(MinBetweener); ok { + return minbetweener.MinBetween(at, bt, opts...) + } + if minbetweener, ok = bt.Engine().(MinBetweener); ok { + return minbetweener.MinBetween(at, bt, opts...) + } + return nil, errors.New("Neither engines of either operand support MinBetween") + + } else { // at least one of the operands is a scalar + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false // a Scalar-Tensor * b Tensor + tmp := at + at = bt + bt = tmp + } else { + leftTensor = true // a Tensor * b Scalar-Tensor + } + + if oe != nil { + return oe.MinBetweenScalar(at, bt, leftTensor, opts...) + } + if oe, ok = bt.Engine().(StandardEngine); ok { + return oe.MinBetweenScalar(at, bt, leftTensor, opts...) + } + if minbetweener, ok = at.Engine().(MinBetweener); ok { + return minbetweener.MinBetweenScalar(at, bt, leftTensor, opts...) + } + if minbetweener, ok = bt.Engine().(MinBetweener); ok { + return minbetweener.MinBetweenScalar(at, bt, leftTensor, opts...) + } + return nil, errors.New("Neither engines of either operand support MinBetween") + } + + default: + if oe != nil { + return oe.MinBetweenScalar(at, bt, true, opts...)
+ } + if minbetweener, ok = at.Engine().(MinBetweener); ok { + return minbetweener.MinBetweenScalar(at, bt, true, opts...) + } + return nil, errors.New("Operand A's engine does not support MinBetween") + } + default: + switch bt := b.(type) { + case Tensor: + if oe, ok = bt.Engine().(StandardEngine); ok { + return oe.MinBetweenScalar(bt, at, false, opts...) + } + if minbetweener, ok = bt.Engine().(MinBetweener); ok { + return minbetweener.MinBetweenScalar(bt, at, false, opts...) + } + return nil, errors.New("Operand B's engine does not support MinBetween") + default: + return nil, errors.Errorf("Cannot perform MinBetween of %T and %T", a, b) + } + } + panic("Unreachable") +} + +func MaxBetween(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { + var maxbetweener MaxBetweener + var oe StandardEngine + var ok bool + switch at := a.(type) { + case Tensor: + oe, _ = at.Engine().(StandardEngine) + switch bt := b.(type) { + case Tensor: + if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor addition + if oe != nil { + return oe.MaxBetween(at, bt, opts...) + } + if oe, ok = bt.Engine().(StandardEngine); ok { + return oe.MaxBetween(at, bt, opts...) + } + if maxbetweener, ok = at.Engine().(MaxBetweener); ok { + return maxbetweener.MaxBetween(at, bt, opts...) + } + if maxbetweener, ok = bt.Engine().(MaxBetweener); ok { + return maxbetweener.MaxBetween(at, bt, opts...) + } + return nil, errors.New("Neither engines of either operand support MaxBetween") + + } else { // at least one of the operands is a scalar + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false // a Scalar-Tensor * b Tensor + tmp := at + at = bt + bt = tmp + } else { + leftTensor = true // a Tensor * b Scalar-Tensor + } + + if oe != nil { + return oe.MaxBetweenScalar(at, bt, leftTensor, opts...) + } + if oe, ok = bt.Engine().(StandardEngine); ok { + return oe.MaxBetweenScalar(at, bt, leftTensor, opts...) + } + if maxbetweener, ok = at.Engine().(MaxBetweener); ok { + return maxbetweener.MaxBetweenScalar(at, bt, leftTensor, opts...) + } + if maxbetweener, ok = bt.Engine().(MaxBetweener); ok { + return maxbetweener.MaxBetweenScalar(at, bt, leftTensor, opts...) + } + return nil, errors.New("Neither engines of either operand support MaxBetween") + } + + default: + if oe != nil { + return oe.MaxBetweenScalar(at, bt, true, opts...) + } + if maxbetweener, ok = at.Engine().(MaxBetweener); ok { + return maxbetweener.MaxBetweenScalar(at, bt, true, opts...) + } + return nil, errors.New("Operand A's engine does not support MaxBetween") + } + default: + switch bt := b.(type) { + case Tensor: + if oe, ok = bt.Engine().(StandardEngine); ok { + return oe.MaxBetweenScalar(bt, at, false, opts...) + } + if maxbetweener, ok = bt.Engine().(MaxBetweener); ok { + return maxbetweener.MaxBetweenScalar(bt, at, false, opts...) + } + return nil, errors.New("Operand B's engine does not support MaxBetween") + default: + return nil, errors.Errorf("Cannot perform MaxBetween of %T and %T", a, b) + } + } + panic("Unreachable") +} diff --git a/api_reduction.go b/api_reduction.go index f146972..63c2257 100644 --- a/api_reduction.go +++ b/api_reduction.go @@ -2,26 +2,52 @@ package tensor import "github.com/pkg/errors" -// Sum sums a Tensor along the given axes +// Sum sums a Tensor along the given axes. func Sum(t Tensor, along ...int) (retVal Tensor, err error) { - if sumer, ok := t.Engine().(Sumer); ok { - return sumer.Sum(t, along...) 
+ e := t.Engine() + ctx := ctxFromEngine(e) + if sumer, ok := e.(Sumer); ok { + return sumer.Sum(ctx, t, along...) } return nil, errors.New("Engine does not support Sum()") } +// Prod computes the product of a Tensor along the given axes. +func Prod(t Tensor, along ...int) (retVal Tensor, err error) { + e := t.Engine() + ctx := ctxFromEngine(e) + if proder, ok := e.(Proder); ok { + return proder.Prod(ctx, t, along...) + } + return nil, errors.New("Engine does not support Prod()") +} + +// Max finds the maximum value along the given axes. +func Max(t Tensor, along ...int) (retVal Tensor, err error) { + e := t.Engine() + ctx := ctxFromEngine(e) + if maxer, ok := e.(Maxer); ok { + return maxer.Max(ctx, t, along...) + } + return nil, errors.New("Engine does not support Max()") +} + // Argmax finds the index of the max value along the axis provided func Argmax(t Tensor, axis int) (retVal Tensor, err error) { - if argmaxer, ok := t.Engine().(Argmaxer); ok { - return argmaxer.Argmax(t, axis) + e := t.Engine() + ctx := ctxFromEngine(e) + if argmaxer, ok := e.(Argmaxer); ok { + return argmaxer.Argmax(ctx, t, axis) } return nil, errors.New("Engine does not support Argmax()") } // Argmin finds the index of the min value along the axis provided func Argmin(t Tensor, axis int) (retVal Tensor, err error) { - if argminer, ok := t.Engine().(Argminer); ok { - return argminer.Argmin(t, axis) + e := t.Engine() + ctx := ctxFromEngine(e) + if argminer, ok := e.(Argminer); ok { + return argminer.Argmin(ctx, t, axis) } return nil, errors.New("Engine does not support Argmax()") } diff --git a/api_unary.go b/api_unary.go index b1afe71..4c81e33 100644 --- a/api_unary.go +++ b/api_unary.go @@ -1,9 +1,9 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import "github.com/pkg/errors" +// Code generated by genlib2. DO NOT EDIT. + func Neg(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { e := a.Engine() if neger, ok := e.(Neger); ok { diff --git a/api_unary_generated_test.go b/api_unary_generated_test.go index 31a23f2..64813ae 100644 --- a/api_unary_generated_test.go +++ b/api_unary_generated_test.go @@ -1,17 +1,19 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( "testing" "testing/quick" + + "gorgonia.org/dtype" ) +// Code generated by genlib2. DO NOT EDIT.
+ func TestNeg(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Neger) we = we || !ok @@ -37,7 +39,7 @@ func TestSquare(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Squarer) we = we || !ok @@ -48,7 +50,7 @@ func TestSquare(t *testing.T) { } return true } - if err := typeclassCheck(a.Dtype(), floatcmplxTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.FloatComplex); err != nil { return true // uninvertible due to type class implementation issues } Sqrt(ret, UseUnsafe()) @@ -66,7 +68,7 @@ func TestCube(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Cuber) we = we || !ok @@ -77,7 +79,7 @@ func TestCube(t *testing.T) { } return true } - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true // uninvertible due to type class implementation issues } Cbrt(ret, UseUnsafe()) @@ -95,7 +97,7 @@ func TestExp(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Exper) we = we || !ok @@ -121,7 +123,7 @@ func TestLog(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Loger) we = we || !ok @@ -147,7 +149,7 @@ func TestSqrt(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Sqrter) we = we || !ok @@ -173,7 +175,7 @@ func TestCbrt(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Cbrter) we = we || !ok @@ -199,7 +201,7 @@ func TestNeg_unsafe(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Neger) we = we || !ok @@ -218,7 +220,6 @@ func TestNeg_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -230,7 +231,7 @@ func TestSquare_unsafe(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Squarer) we = we || !ok @@ -241,7 +242,7 @@ func TestSquare_unsafe(t *testing.T) { } return true } - if err := typeclassCheck(a.Dtype(), floatcmplxTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.FloatComplex); err != nil { return true // uninvertible due to type class implementation 
issues } Sqrt(ret, UseUnsafe()) @@ -252,7 +253,6 @@ func TestSquare_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -264,7 +264,7 @@ func TestCube_unsafe(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Cuber) we = we || !ok @@ -275,7 +275,7 @@ func TestCube_unsafe(t *testing.T) { } return true } - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true // uninvertible due to type class implementation issues } Cbrt(ret, UseUnsafe()) @@ -286,7 +286,6 @@ func TestCube_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -298,7 +297,7 @@ func TestExp_unsafe(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Exper) we = we || !ok @@ -317,7 +316,6 @@ func TestExp_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -329,7 +327,7 @@ func TestLog_unsafe(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Loger) we = we || !ok @@ -348,7 +346,6 @@ func TestLog_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -360,7 +357,7 @@ func TestSqrt_unsafe(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Sqrter) we = we || !ok @@ -379,7 +376,6 @@ func TestSqrt_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -391,7 +387,7 @@ func TestCbrt_unsafe(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Cbrter) we = we || !ok @@ -410,7 +406,6 @@ func TestCbrt_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -423,7 +418,7 @@ func TestNeg_reuse(t *testing.T) { a := q.Clone().(*Dense) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Neger) we = we || !ok @@ -455,7 +450,7 @@ func TestSquare_reuse(t *testing.T) { a := q.Clone().(*Dense) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Squarer) we = we || !ok @@ -466,7 +461,7 @@ func TestSquare_reuse(t *testing.T) { } return true } - if err := typeclassCheck(a.Dtype(), floatcmplxTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.FloatComplex); err != nil { return true // uninvertible due to type class implementation issues } Sqrt(ret, UseUnsafe()) @@ -490,7 +485,7 @@ func 
TestCube_reuse(t *testing.T) { a := q.Clone().(*Dense) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Cuber) we = we || !ok @@ -501,7 +496,7 @@ func TestCube_reuse(t *testing.T) { } return true } - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true // uninvertible due to type class implementation issues } Cbrt(ret, UseUnsafe()) @@ -525,7 +520,7 @@ func TestExp_reuse(t *testing.T) { a := q.Clone().(*Dense) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Exper) we = we || !ok @@ -557,7 +552,7 @@ func TestLog_reuse(t *testing.T) { a := q.Clone().(*Dense) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Loger) we = we || !ok @@ -589,7 +584,7 @@ func TestSqrt_reuse(t *testing.T) { a := q.Clone().(*Dense) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Sqrter) we = we || !ok @@ -621,7 +616,7 @@ func TestCbrt_reuse(t *testing.T) { a := q.Clone().(*Dense) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Cbrter) we = we || !ok @@ -655,7 +650,7 @@ func TestNeg_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Neger) we = we || !ok @@ -688,7 +683,7 @@ func TestSquare_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Squarer) we = we || !ok @@ -699,7 +694,7 @@ func TestSquare_incr(t *testing.T) { } return true } - if err := typeclassCheck(a.Dtype(), floatcmplxTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.FloatComplex); err != nil { return true // uninvertible due to type class implementation issues } if ret, err = Sub(ret, identityVal(100, a.Dtype()), UseUnsafe()); err != nil { @@ -724,7 +719,7 @@ func TestCube_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Cuber) we = we || !ok @@ -735,7 +730,7 @@ func TestCube_incr(t *testing.T) { } return true } - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true // uninvertible due to type class implementation issues } if ret, err = Sub(ret, identityVal(100, a.Dtype()), UseUnsafe()); err != nil { @@ -760,7 +755,7 @@ func TestExp_incr(t *testing.T) { correct := a.Clone().(*Dense) 
incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Exper) we = we || !ok @@ -793,7 +788,7 @@ func TestLog_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Loger) we = we || !ok @@ -826,7 +821,7 @@ func TestSqrt_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Sqrter) we = we || !ok @@ -859,7 +854,7 @@ func TestCbrt_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Cbrter) we = we || !ok diff --git a/api_unary_test.go b/api_unary_test.go index 9c735e6..25b68f7 100644 --- a/api_unary_test.go +++ b/api_unary_test.go @@ -1,14 +1,15 @@ package tensor import ( + "math" "math/rand" "testing" "testing/quick" "time" - "math" - "github.com/stretchr/testify/assert" "github.com/chewxy/math32" + "github.com/stretchr/testify/assert" + "gorgonia.org/dtype" ) /* @@ -354,12 +355,12 @@ func TestInvSqrt(t *testing.T) { a := q.Clone().(*Dense) b := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(InvSqrter) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := InvSqrt(a) @@ -387,12 +388,12 @@ func TestInvSqrt(t *testing.T) { a := q.Clone().(*Dense) b := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(InvSqrter) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := InvSqrt(a, UseUnsafe()) @@ -426,12 +427,12 @@ func TestInvSqrt(t *testing.T) { reuse := q.Clone().(*Dense) reuse.Zero() correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(InvSqrter) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := InvSqrt(a, WithReuse(reuse)) @@ -466,12 +467,12 @@ func TestInvSqrt(t *testing.T) { incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(InvSqrter) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := InvSqrt(a, WithIncr(incr)) @@ -509,12 +510,12 @@ func TestInv(t 
*testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Inver) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Inv(a) @@ -541,12 +542,12 @@ func TestInv(t *testing.T) { a := q.Clone().(*Dense) b := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Inver) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Inv(a, UseUnsafe()) @@ -577,12 +578,12 @@ func TestInv(t *testing.T) { correct := a.Clone().(*Dense) reuse := a.Clone().(*Dense) reuse.Zero() - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Inver) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Inv(a, WithReuse(reuse)) @@ -613,12 +614,12 @@ func TestInv(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Inver) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Inv(a, WithIncr(incr)) @@ -654,12 +655,12 @@ func TestLog10(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Log10er) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Log10(a) @@ -683,18 +684,17 @@ func TestLog10(t *testing.T) { t.Errorf("Inv tests for Log10 failed: %v", err) } - // unsafe invFn = func(q *Dense) bool { a := q.Clone().(*Dense) b := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Log10er) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Log10(a, UseUnsafe()) @@ -720,19 +720,18 @@ func TestLog10(t *testing.T) { t.Errorf("Inv tests using unsafe for Log10 failed: %v", err) } - // reuse invFn = func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) reuse := a.Clone().(*Dense) reuse.Zero() - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Log10er) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), 
floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Log10(a, WithReuse(reuse)) @@ -764,12 +763,12 @@ func TestLog10(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Log10er) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Log10(a, WithIncr(incr)) @@ -808,10 +807,10 @@ func TestAbs(t *testing.T) { correct := New(Of(Bool), WithShape(q.Shape().Clone()...)) correct.Memset(true) // we'll exclude everything other than ordtypes because complex numbers cannot be abs'd - if err := typeclassCheck(a.Dtype(), ordTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Ord); err != nil { return true } - we, willFailEq := willerr(a, signedTypes, nil) + we, willFailEq := willerr(a, dtype.Signed, nilTC) _, ok := q.Engine().(Abser) we = we || !ok @@ -836,19 +835,18 @@ func TestAbs(t *testing.T) { } } - func TestTanh(t *testing.T) { var r *rand.Rand // default invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Tanher) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Tanh(a) @@ -885,12 +883,12 @@ func TestTanh(t *testing.T) { invFn = func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Tanher) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Tanh(a, UseUnsafe()) @@ -926,19 +924,18 @@ func TestTanh(t *testing.T) { t.Errorf("Inv tests using unsafe for Tanh failed: %v", err) } - // reuse invFn = func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) reuse := a.Clone().(*Dense) reuse.Zero() - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Tanher) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Tanh(a, WithReuse(reuse)) @@ -973,7 +970,6 @@ func TestTanh(t *testing.T) { t.Errorf("Inv tests using unsafe for Tanh failed: %v", err) } - // incr invFn = func(q *Dense) bool { a := q.Clone().(*Dense) @@ -981,12 +977,12 @@ func TestTanh(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Tanher) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err 
:= Tanh(a, WithIncr(incr)) @@ -1033,12 +1029,12 @@ func TestLog2(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Log2er) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Log2(a) @@ -1062,18 +1058,17 @@ func TestLog2(t *testing.T) { t.Errorf("Inv tests for Log2 failed: %v", err) } - // unsafe invFn = func(q *Dense) bool { a := q.Clone().(*Dense) b := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Log2er) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Log2(a, UseUnsafe()) @@ -1099,19 +1094,18 @@ func TestLog2(t *testing.T) { t.Errorf("Inv tests using unsafe for Log2 failed: %v", err) } - // reuse invFn = func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) reuse := a.Clone().(*Dense) reuse.Zero() - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Log2er) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Log2(a, WithReuse(reuse)) @@ -1143,12 +1137,12 @@ func TestLog2(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Log2er) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Log2(a, WithIncr(incr)) @@ -1177,4 +1171,4 @@ func TestLog2(t *testing.T) { t.Errorf("Inv tests using unsafe for Log2 failed: %v", err) } -} \ No newline at end of file +} diff --git a/api_utils.go b/api_utils.go index 12bea19..3cf55f0 100644 --- a/api_utils.go +++ b/api_utils.go @@ -53,11 +53,11 @@ func SortIndex(in interface{}) (out []int) { // SampleIndex samples a slice or a Tensor. // TODO: tidy this up. 
func SampleIndex(in interface{}) int { - var l int + // var l int switch list := in.(type) { case []int: var sum, i int - l = len(list) + // l = len(list) r := rand.Int() for { sum += list[i] @@ -69,7 +69,7 @@ func SampleIndex(in interface{}) int { case []float64: var sum float64 var i int - l = len(list) + // l = len(list) r := rand.Float64() for { sum += list[i] @@ -85,7 +85,7 @@ func SampleIndex(in interface{}) int { var sum float64 r := rand.Float64() data := list.Float64s() - l = len(data) + // l = len(data) for { datum := data[i] if math.IsNaN(datum) || math.IsInf(datum, 0) { @@ -102,7 +102,7 @@ func SampleIndex(in interface{}) int { var sum float32 r := rand.Float32() data := list.Float32s() - l = len(data) + // l = len(data) for { datum := data[i] if math32.IsNaN(datum) || math32.IsInf(datum, 0) { @@ -121,5 +121,5 @@ func SampleIndex(in interface{}) int { default: panic("Not yet implemented") } - return l - 1 + return -1 } diff --git a/array.go b/array.go index 62fa8a2..e805405 100644 --- a/array.go +++ b/array.go @@ -3,53 +3,31 @@ package tensor import ( "fmt" "reflect" + "sync" "unsafe" "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) // array is the underlying generic array. type array struct { storage.Header // the header - the Go representation (a slice) - t Dtype // the element type - v interface{} // an additional reference to the underlying slice. This is not strictly necessary, but does improve upon anything that calls .Data() -} - -// makeHeader makes a array Header -func makeHeader(t Dtype, length int) storage.Header { - size := int(calcMemSize(t, length)) - s := make([]byte, size) - return storage.Header{ - Ptr: unsafe.Pointer(&s[0]), - L: length, - C: length, - } + t dtype.Dtype // the element type } // makeArray makes an array. The memory allocation is handled by Go -func makeArray(t Dtype, length int) array { - hdr := makeHeader(t, length) - return makeArrayFromHeader(hdr, t) -} - -// makeArrayFromHeader makes an array given a header -func makeArrayFromHeader(hdr storage.Header, t Dtype) array { - // build a type of []T - shdr := reflect.SliceHeader{ - Data: uintptr(hdr.Ptr), - Len: hdr.L, - Cap: hdr.C, +func makeArray(t dtype.Dtype, length int) array { + v := malloc(t, length) + hdr := storage.Header{ + Raw: v, } - sliceT := reflect.SliceOf(t.Type) - ptr := unsafe.Pointer(&shdr) - val := reflect.Indirect(reflect.NewAt(sliceT, ptr)) - return array{ Header: hdr, t: t, - v: val.Interface(), } + } // arrayFromSlice creates an array from a slice. If x is not a slice, it will panic. 
@@ -60,103 +38,82 @@ func arrayFromSlice(x interface{}) array { } elT := xT.Elem() - xV := reflect.ValueOf(x) - ptr := xV.Pointer() - uptr := unsafe.Pointer(ptr) - return array{ Header: storage.Header{ - Ptr: uptr, - L: xV.Len(), - C: xV.Cap(), + Raw: storage.AsByteSlice(x), }, - t: Dtype{elT}, - v: x, + t: dtype.Dtype{elT}, } } +func (a *array) Len() int { return a.Header.TypedLen(a.t.Type) } + +func (a *array) Cap() int { return a.Header.TypedLen(a.t.Type) } + +// fromSlice populates the value from a slice func (a *array) fromSlice(x interface{}) { xT := reflect.TypeOf(x) if xT.Kind() != reflect.Slice { panic("Expected a slice") } elT := xT.Elem() - xV := reflect.ValueOf(x) - ptr := xV.Pointer() - uptr := unsafe.Pointer(ptr) - a.Ptr = uptr - a.L = xV.Len() - a.C = xV.Cap() - a.t = Dtype{elT} - a.v = x + a.Raw = storage.AsByteSlice(x) + a.t = dtype.Dtype{elT} } -func (a *array) fix() { - if a.v == nil { - shdr := reflect.SliceHeader{ - Data: uintptr(a.Ptr), - Len: a.L, - Cap: a.C, +// fromSliceOrTensor populates the value from a slice or anything that can form an array +func (a *array) fromSliceOrArrayer(x interface{}) { + if T, ok := x.(arrayer); ok { + xp := T.arrPtr() + + // if the underlying array hasn't been allocated, or not enough has been allocated + if a.Header.Raw == nil { + a.Header.Raw = malloc(xp.t, xp.Len()) } - sliceT := reflect.SliceOf(a.t.Type) - ptr := unsafe.Pointer(&shdr) - val := reflect.Indirect(reflect.NewAt(sliceT, ptr)) - a.v = val.Interface() + + a.t = xp.t + copyArray(a, T.arrPtr()) + return } + a.fromSlice(x) } // byteSlice casts the underlying slice into a byte slice. Useful for copying and zeroing, but not much else -func (a array) byteSlice() []byte { - return storage.AsByteSlice(&a.Header, a.t.Type) -} +func (a array) byteSlice() []byte { return a.Header.Raw } // sliceInto creates a slice. Instead of returning an array, which would cause a lot of reallocations, sliceInto expects a array to // already have been created. 
This allows repetitive actions to be done without having to have many pointless allocation func (a *array) sliceInto(i, j int, res *array) { - base := uintptr(a.Ptr) - c := a.C + c := a.Cap() if i < 0 || j < i || j > c { panic(fmt.Sprintf("Cannot slice %v - index %d:%d is out of bounds", a, i, j)) } - res.L = j - i - res.C = c - i + s := i * int(a.t.Size()) + e := j * int(a.t.Size()) + c = c - i + + res.Raw = a.Raw[s:e] - if c-1 > 0 { - res.Ptr = storage.ElementAt(i, unsafe.Pointer(base), a.t.Size()) - } else { - // don't advance pointer - res.Ptr = unsafe.Pointer(base) - } - res.fix() } +// slice slices an array func (a array) slice(start, end int) array { - if end > a.L { + if end > a.Len() { panic("Index out of range") } if end < start { panic("Index out of range") } - L := end - start - C := a.C - start + s := start * int(a.t.Size()) + e := end * int(a.t.Size()) - var startptr unsafe.Pointer - if a.C-start > 0 { - startptr = storage.ElementAt(start, a.Ptr, a.t.Size()) - } else { - startptr = a.Ptr - } - - hdr := storage.Header{ - Ptr: startptr, - L: L, - C: C, + return array{ + Header: storage.Header{Raw: a.Raw[s:e]}, + t: a.t, } - - return makeArrayFromHeader(hdr, a.t) } // swap swaps the elements i and j in the array @@ -199,16 +156,25 @@ func (a *array) swap(i, j int) { /* *Array is a Memory */ // Uintptr returns the pointer of the first value of the slab -func (t *array) Uintptr() uintptr { return uintptr(t.Ptr) } +func (a *array) Uintptr() uintptr { return uintptr(unsafe.Pointer(&a.Header.Raw[0])) } // MemSize returns how big the slice is in bytes -func (t *array) MemSize() uintptr { return uintptr(t.L) * t.t.Size() } - -// Pointer returns the pointer of the first value of the slab, as an unsafe.Pointer -func (t *array) Pointer() unsafe.Pointer { return t.Ptr } +func (a *array) MemSize() uintptr { return uintptr(len(a.Header.Raw)) } // Data returns the representation of a slice. -func (a array) Data() interface{} { return a.v } +func (a array) Data() interface{} { + // build a type of []T + shdr := reflect.SliceHeader{ + Data: a.Uintptr(), + Len: a.Len(), + Cap: a.Cap(), + } + sliceT := reflect.SliceOf(a.t.Type) + ptr := unsafe.Pointer(&shdr) + val := reflect.Indirect(reflect.NewAt(sliceT, ptr)) + return val.Interface() + +} // Zero zeroes out the underlying array of the *Dense tensor. 
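A quick worked example of the byte-offset arithmetic used by the rewritten slice/sliceInto above (a sketch only, assuming the unexported helpers keep the signatures shown in this diff and a float64 element size of 8 bytes):

	arr := makeArray(Float64, 6) // len(arr.Raw) == 6*8 == 48 bytes
	view := arr.slice(2, 5)      // s = 2*8 = 16, e = 5*8 = 40, so view.Raw = arr.Raw[16:40]
	_ = view.Len()               // 24 bytes / 8 bytes per element == 3 elements (indices 2, 3 and 4)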
func (a array) Zero() { @@ -226,10 +192,10 @@ func (a array) Zero() { } return } - ptr := uintptr(a.Ptr) - for i := 0; i < a.L; i++ { - want := ptr + uintptr(i)*a.t.Size() - val := reflect.NewAt(a.t, unsafe.Pointer(want)) + + l := a.Len() + for i := 0; i < l; i++ { + val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) val = reflect.Indirect(val) val.Set(reflect.Zero(a.t)) } @@ -240,8 +206,14 @@ func (a *array) rtype() reflect.Type { return a.t.Type } /* MEMORY MOVEMENT STUFF */ +// malloc is standard Go allocation of a block of memory - the plus side is that Go manages the memory +func malloc(t dtype.Dtype, length int) []byte { + size := int(calcMemSize(t, length)) + return make([]byte, size) +} + // calcMemSize calulates the memory size of an array (given its size) -func calcMemSize(dt Dtype, size int) int64 { +func calcMemSize(dt dtype.Dtype, size int) int64 { return int64(dt.Size()) * int64(size) } @@ -288,6 +260,7 @@ func copyDense(dst, src DenseTensor) int { // return copyArray(dst.arr(), src.arr()) } +// copyDenseSliced copies a DenseTensor, but both are sliced func copyDenseSliced(dst DenseTensor, dstart, dend int, src DenseTensor, sstart, send int) int { if dst.Dtype() != src.Dtype() { panic("Cannot copy DenseTensors of different types") @@ -306,36 +279,62 @@ func copyDenseSliced(dst DenseTensor, dstart, dend int, src DenseTensor, sstart, } } if e := src.Engine(); e != nil { - d := dst.arr().slice(dstart, dend) - s := src.arr().slice(sstart, send) - if err := e.Memcpy(&d, &s); err != nil { - panic(err) + darr := dst.arr() + sarr := src.arr() + da := darr.slice(dstart, dend) + sa := sarr.slice(sstart, send) + + switch e.(type) { + case NonStdEngine: + if err := e.Memcpy(&da, &sa); err != nil { + panic(err) + } + default: + // THIS IS AN OPTIMIZATION. REVISIT WHEN NEEDED. + // + // THE PURPOSE of this optimization is to make this perform better under + // default circumstances. + // + // The original code simply uses t.Engine().Memcpy(&dSlice, &tSlice). + // A variant can still be seen in the NonStdEngine case above. + // + // The `array.slice()` method has been optimized to return `array2`, which is a + // non-heap allocated type. + // a value of `array2` cannot have its address taken - e.g. + // var a array2 + // doSomething(&a) // ← this cannot be done + // + // We *could* make `array2` implement Memory. But then a lot of runtime.convT2I and + // runtime.convI2T would be called. Which defeats the purpose of making things fast. + // + // So instead, we check to see if the Engine uses standard allocation methods. + // Typically this means `StdEng`. 
+ // + // If so, we directly use storage.Copy instead of using the engine + storage.Copy(da.t.Type, &da.Header, &sa.Header) } - return d.Len() + + return da.Len() } return copyArraySliced(dst.arr(), dstart, dend, src.arr(), sstart, send) } +// copyDenseIter copies a DenseTensor, with iterator func copyDenseIter(dst, src DenseTensor, diter, siter Iterator) (int, error) { if dst.Dtype() != src.Dtype() { panic("Cannot copy Dense arrays of different types") } - // do not use requiresIterator because requiresIterator has particular optimizations for operations (like add, sub etc) - var dstOK, srcOK bool - if dstView, ok := dst.(View); ok && dstView.IsMaterializable() { - srcOK = true - } - if srcView, ok := src.(View); ok && srcView.IsMaterializable() { - dstOK = true - } - - if !dstOK && !srcOK { + // if they all don't need iterators, and have the same data order + if !dst.RequiresIterator() && !src.RequiresIterator() && dst.DataOrder().HasSameOrder(src.DataOrder()) { return copyDense(dst, src), nil } - if !dst.IsNativelyAccessible() || !src.IsNativelyAccessible() { - return 0, errors.Errorf(inaccessibleData, "copy") + if !dst.IsNativelyAccessible() { + return 0, errors.Errorf(inaccessibleData, dst) + } + if !src.IsNativelyAccessible() { + return 0, errors.Errorf(inaccessibleData, src) } if diter == nil { @@ -345,6 +344,7 @@ func copyDenseIter(dst, src DenseTensor, diter, siter Iterator) (int, error) { siter = FlatIteratorFromDense(src) } + // if it's a masked tensor, we copy the mask as well if ms, ok := src.(MaskedTensor); ok && ms.IsMasked() { if md, ok := dst.(MaskedTensor); ok { dmask := md.Mask() @@ -360,66 +360,79 @@ func copyDenseIter(dst, src DenseTensor, diter, siter Iterator) (int, error) { return storage.CopyIter(dst.rtype(), dst.hdr(), src.hdr(), diter, siter), nil } -func getPointer(a interface{}) unsafe.Pointer { +type scalarPtrCount struct { + Ptr unsafe.Pointer + Count int +} + +// scalarRCLock is a lock for the reference counting list. 
+var scalarRCLock sync.Mutex + +// scalarRC is a bunch of reference counted pointers to scalar values +var scalarRC = make(map[uintptr]*sync.Pool) // uintptr is the size, the pool stores []byte + +func scalarPool(size uintptr) *sync.Pool { + scalarRCLock.Lock() + pool, ok := scalarRC[size] + if !ok { + pool = &sync.Pool{ + New: func() interface{} { return make([]byte, size) }, + } + scalarRC[size] = pool + } + scalarRCLock.Unlock() + return pool +} + +func allocScalar(a interface{}) []byte { + atype := reflect.TypeOf(a) + size := atype.Size() + pool := scalarPool(size) + return pool.Get().([]byte) +} + +func freeScalar(bs []byte) { + if bs == nil { + return + } + + // zero out + for i := range bs { + bs[i] = 0 + } + + size := uintptr(len(bs)) + + // put it back into pool + pool := scalarPool(size) + pool.Put(bs) +} + +// scalarToHeader creates a Header from a scalar value +func scalarToHeader(a interface{}) (hdr *storage.Header, newAlloc bool) { + var raw []byte switch at := a.(type) { case Memory: - return at.Pointer() - case bool: - return unsafe.Pointer(&at) - case int: - return unsafe.Pointer(&at) - case int8: - return unsafe.Pointer(&at) - case int16: - return unsafe.Pointer(&at) - case int32: - return unsafe.Pointer(&at) - case int64: - return unsafe.Pointer(&at) - case uint: - return unsafe.Pointer(&at) - case uint8: - return unsafe.Pointer(&at) - case uint16: - return unsafe.Pointer(&at) - case uint32: - return unsafe.Pointer(&at) - case uint64: - return unsafe.Pointer(&at) - case float32: - return unsafe.Pointer(&at) - case float64: - return unsafe.Pointer(&at) - case complex64: - return unsafe.Pointer(&at) - case complex128: - return unsafe.Pointer(&at) - case string: - return unsafe.Pointer(&at) - case uintptr: - return unsafe.Pointer(&at) - case unsafe.Pointer: - return at - - // POINTERS - - case *float32: - return unsafe.Pointer(at) - case *float64: - return unsafe.Pointer(at) - case *complex64: - return unsafe.Pointer(at) - case *complex128: - return unsafe.Pointer(at) - } - - panic("Cannot get pointer") + raw = storage.FromMemory(at.Uintptr(), at.MemSize()) + default: + raw = allocScalar(a) + newAlloc = true + } + hdr = borrowHeader() + hdr.Raw = raw + if newAlloc { + copyScalarToPrealloc(a, hdr.Raw) + } + + return hdr, newAlloc } -func scalarToHeader(a interface{}) *storage.Header { - hdr := borrowHeader() - hdr.Ptr = getPointer(a) - hdr.L = 1 - hdr.C = 1 - return hdr +func copyScalarToPrealloc(a interface{}, bs []byte) { + xV := reflect.ValueOf(a) + xT := reflect.TypeOf(a) + + p := unsafe.Pointer(&bs[0]) + v := reflect.NewAt(xT, p) + reflect.Indirect(v).Set(xV) + return } diff --git a/array_getset.go b/array_getset.go index fe65438..69bcf95 100644 --- a/array_getset.go +++ b/array_getset.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -7,8 +5,11 @@ import ( "unsafe" "github.com/pkg/errors" + "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + // Set sets the value of the underlying array at the index i. 
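The copyScalarToPrealloc helper above writes a scalar into a pooled byte slice via reflect.NewAt; a minimal standalone sketch of that idea (hypothetical variable names, relying only on the reflect and unsafe imports already used in this file):

	bs := make([]byte, 8) // what the size-8 pool would hand back for a float64
	v := reflect.NewAt(reflect.TypeOf(float64(0)), unsafe.Pointer(&bs[0]))
	reflect.Indirect(v).Set(reflect.ValueOf(3.14)) // bs now holds the raw bytes of 3.14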
func (a *array) Set(i int, x interface{}) { switch a.t.Kind() { @@ -68,9 +69,7 @@ func (a *array) Set(i int, x interface{}) { a.SetUnsafePointer(i, xv) default: xv := reflect.ValueOf(x) - ptr := uintptr(a.Ptr) - want := ptr + uintptr(i)*a.t.Size() - val := reflect.NewAt(a.t, unsafe.Pointer(want)) + val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) val = reflect.Indirect(val) val.Set(xv) } @@ -81,43 +80,60 @@ func (a *array) Get(i int) interface{} { switch a.t.Kind() { case reflect.Bool: return a.GetB(i) + case reflect.Int: return a.GetI(i) + case reflect.Int8: return a.GetI8(i) + case reflect.Int16: return a.GetI16(i) + case reflect.Int32: return a.GetI32(i) + case reflect.Int64: return a.GetI64(i) + case reflect.Uint: return a.GetU(i) + case reflect.Uint8: return a.GetU8(i) + case reflect.Uint16: return a.GetU16(i) + case reflect.Uint32: return a.GetU32(i) + case reflect.Uint64: return a.GetU64(i) + case reflect.Uintptr: return a.GetUintptr(i) + case reflect.Float32: return a.GetF32(i) + case reflect.Float64: return a.GetF64(i) + case reflect.Complex64: return a.GetC64(i) + case reflect.Complex128: return a.GetC128(i) + case reflect.String: return a.GetStr(i) + case reflect.UnsafePointer: return a.GetUnsafePointer(i) + default: - at := uintptr(a.Ptr) + uintptr(i)*a.t.Size() - val := reflect.NewAt(a.t, unsafe.Pointer(at)) + val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) val = reflect.Indirect(val) return val.Interface() } @@ -291,25 +307,24 @@ func (a *array) Memset(x interface{}) error { } xv := reflect.ValueOf(x) - ptr := uintptr(a.Ptr) - for i := 0; i < a.L; i++ { - want := ptr + uintptr(i)*a.t.Size() - val := reflect.NewAt(a.t, unsafe.Pointer(want)) + l := a.Len() + for i := 0; i < l; i++ { + val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) val = reflect.Indirect(val) val.Set(xv) } return nil } -func (t *array) memsetIter(x interface{}, it Iterator) (err error) { +func (a *array) memsetIter(x interface{}, it Iterator) (err error) { var i int - switch t.t { + switch a.t { case Bool: xv, ok := x.(bool) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Bools() + data := a.Bools() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -317,9 +332,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Int: xv, ok := x.(int) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Ints() + data := a.Ints() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -327,9 +342,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Int8: xv, ok := x.(int8) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Int8s() + data := a.Int8s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -337,9 +352,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Int16: xv, ok := x.(int16) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Int16s() + data := a.Int16s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -347,9 +362,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Int32: xv, ok := x.(int32) if !ok { - return 
errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Int32s() + data := a.Int32s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -357,9 +372,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Int64: xv, ok := x.(int64) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Int64s() + data := a.Int64s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -367,9 +382,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Uint: xv, ok := x.(uint) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Uints() + data := a.Uints() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -377,9 +392,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Uint8: xv, ok := x.(uint8) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Uint8s() + data := a.Uint8s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -387,9 +402,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Uint16: xv, ok := x.(uint16) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Uint16s() + data := a.Uint16s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -397,9 +412,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Uint32: xv, ok := x.(uint32) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Uint32s() + data := a.Uint32s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -407,9 +422,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Uint64: xv, ok := x.(uint64) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Uint64s() + data := a.Uint64s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -417,9 +432,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Uintptr: xv, ok := x.(uintptr) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Uintptrs() + data := a.Uintptrs() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -427,9 +442,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Float32: xv, ok := x.(float32) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Float32s() + data := a.Float32s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -437,9 +452,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Float64: xv, ok := x.(float64) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Float64s() + data := a.Float64s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -447,9 +462,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Complex64: xv, ok := x.(complex64) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Complex64s() + data := a.Complex64s() 
for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -457,9 +472,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Complex128: xv, ok := x.(complex128) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Complex128s() + data := a.Complex128s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -467,9 +482,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case String: xv, ok := x.(string) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Strings() + data := a.Strings() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -477,19 +492,17 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case UnsafePointer: xv, ok := x.(unsafe.Pointer) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.UnsafePointers() + data := a.UnsafePointers() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } err = handleNoOp(err) default: xv := reflect.ValueOf(x) - ptr := uintptr(t.Ptr) for i, err = it.Next(); err == nil; i, err = it.Next() { - want := ptr + uintptr(i)*t.t.Size() - val := reflect.NewAt(t.t, unsafe.Pointer(want)) + val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) val = reflect.Indirect(val) val.Set(xv) } @@ -505,16 +518,17 @@ func (a array) Eq(other interface{}) bool { return false } - if oa.L != a.L { - return false - } - - if oa.C != a.C { + if oa.Len() != a.Len() { return false } + /* + if oa.C != a.C { + return false + } + */ // same exact thing - if uintptr(oa.Ptr) == uintptr(a.Ptr) { + if uintptr(unsafe.Pointer(&oa.Header.Raw[0])) == uintptr(unsafe.Pointer(&a.Header.Raw[0])) { return true } @@ -628,7 +642,7 @@ func (a array) Eq(other interface{}) bool { } } default: - for i := 0; i < a.L; i++ { + for i := 0; i < a.Len(); i++ { if !reflect.DeepEqual(a.Get(i), oa.Get(i)) { return false } @@ -639,124 +653,122 @@ func (a array) Eq(other interface{}) bool { return false } -func (t *array) zeroIter(it Iterator) (err error) { +func (a *array) zeroIter(it Iterator) (err error) { var i int - switch t.t { + switch a.t { case Bool: - data := t.Bools() + data := a.Bools() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = false } err = handleNoOp(err) case Int: - data := t.Ints() + data := a.Ints() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Int8: - data := t.Int8s() + data := a.Int8s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Int16: - data := t.Int16s() + data := a.Int16s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Int32: - data := t.Int32s() + data := a.Int32s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Int64: - data := t.Int64s() + data := a.Int64s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Uint: - data := t.Uints() + data := a.Uints() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Uint8: - data := t.Uint8s() + data := a.Uint8s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Uint16: - data := t.Uint16s() + data := 
a.Uint16s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Uint32: - data := t.Uint32s() + data := a.Uint32s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Uint64: - data := t.Uint64s() + data := a.Uint64s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Uintptr: - data := t.Uintptrs() + data := a.Uintptrs() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Float32: - data := t.Float32s() + data := a.Float32s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Float64: - data := t.Float64s() + data := a.Float64s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Complex64: - data := t.Complex64s() + data := a.Complex64s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Complex128: - data := t.Complex128s() + data := a.Complex128s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case String: - data := t.Strings() + data := a.Strings() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = "" } err = handleNoOp(err) case UnsafePointer: - data := t.UnsafePointers() + data := a.UnsafePointers() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = nil } err = handleNoOp(err) default: - ptr := uintptr(t.Ptr) for i, err = it.Next(); err == nil; i, err = it.Next() { - want := ptr + uintptr(i)*t.t.Size() - val := reflect.NewAt(t.t, unsafe.Pointer(want)) + val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) val = reflect.Indirect(val) - val.Set(reflect.Zero(t.t)) + val.Set(reflect.Zero(a.t)) } err = handleNoOp(err) } diff --git a/bench.sh b/bench.sh new file mode 100755 index 0000000..8523853 --- /dev/null +++ b/bench.sh @@ -0,0 +1,23 @@ +#!/bin/sh + +old=$1; +new=$2; + +git checkout $old +# https://stackoverflow.com/a/2111099 +branch=$(git symbolic-ref HEAD | sed -e 's,.*/\(.*\),\1,') +echo "Benchmarking $branch (old)" +go test -run=$^ -bench=. > ${branch}.bench +for i in {1..10} + do + go test -run=$^ -bench=. >> ${branch}.bench + done + +git checkout $new +branch=$(git symbolic-ref HEAD | sed -e 's,.*/\(.*\),\1,') +echo "Benchmarking $branch (new)" +go test -run=$^ -bench=. > ${branch}.bench +for i in {1..10} + do + go test -run=$^ -bench=. 
>> ${branch}.bench + done diff --git a/benchmark_dense_matop_test.go b/benchmark_dense_matop_test.go index 15a77ab..2a4ee4a 100644 --- a/benchmark_dense_matop_test.go +++ b/benchmark_dense_matop_test.go @@ -1,6 +1,9 @@ package tensor -import "testing" +import ( + "math/rand" + "testing" +) func BenchmarkDense_Transpose(b *testing.B) { T := New(WithShape(100, 100, 2), WithBacking(Range(Byte, 0, 100*100*2))) @@ -64,8 +67,78 @@ func BenchmarkGetWithIterator(b *testing.B) { f = data[next] } if _, ok := err.(NoOpError); !ok { - b.Error("Error: %v", err) + b.Errorf("Error: %v", err) } } _ = f } + +func BenchmarkComplicatedGet(b *testing.B) { + T := New(WithShape(101, 1, 36, 5), Of(Float64)) + T.T(0, 2, 1, 3) + data := T.Data().([]float64) + var f float64 + b.ResetTimer() + for i := 0; i < b.N; i++ { + it := IteratorFromDense(T) + var next int + + var err error + for next, err = it.Start(); err == nil; next, err = it.Next() { + f = data[next] + } + if _, ok := err.(NoOpError); !ok { + b.Errorf("Error: %v", err) + } + } + _ = f +} + +var atCoords [10000][2]int + +func init() { + for i := range atCoords { + atCoords[i][0] = rand.Intn(100) + atCoords[i][1] = rand.Intn(100) + } +} + +var at1, at2 float64 + +// func BenchmarkAtWithNativeIterator(b *testing.B) { +// T := New(WithShape(100, 100), Of(Float64)) +// it, err := NativeMatrixF64(T) +// if err != nil { +// b.Fatalf("Error: %v", err) +// } + +// var j int +// for i := 0; i < b.N; i++ { + +// if j >= len(atCoords) { +// j = 0 +// } + +// at := atCoords[j] +// at1 = it[at[0]][at[1]] +// j++ +// } +// } + +func BenchmarkAt(b *testing.B) { + T := New(WithShape(100, 100), Of(Float64)) + var j int + for i := 0; i < b.N; i++ { + if j >= len(atCoords) { + j = 0 + } + + at := atCoords[j] + _, err := T.At(at[0], at[1]) + if err != nil { + b.Errorf("Error: %v", err) + } + + j++ + } +} diff --git a/benchmark_dense_repeat_test.go b/benchmark_dense_repeat_test.go new file mode 100644 index 0000000..f22138a --- /dev/null +++ b/benchmark_dense_repeat_test.go @@ -0,0 +1,14 @@ +package tensor + +import "testing" + +func BenchmarkDenseRepeat(b *testing.B) { + for _, tst := range repeatTests { + tst := tst + b.Run(tst.name, func(b *testing.B) { + for n := 0; n < b.N; n++ { + tst.tensor.Repeat(tst.axis, tst.repeats...) + } + }) + } +} diff --git a/blas.go b/blas.go index c708400..9bc170b 100644 --- a/blas.go +++ b/blas.go @@ -15,8 +15,8 @@ var whichblas BLAS type BLAS interface { blas.Float32 blas.Float64 - // blas.Complex64 - // blas.Complex128 + blas.Complex64 + blas.Complex128 } // only blastoise.Implementation() and cubone.Implementation() are batchedBLAS - diff --git a/collections.go b/collections.go index 5f4d075..34e8284 100644 --- a/collections.go +++ b/collections.go @@ -1,30 +1,30 @@ -package tensor - -import "github.com/pkg/errors" - -func densesToTensors(a []*Dense) []Tensor { - retVal := make([]Tensor, len(a)) - for i, t := range a { - retVal[i] = t - } - return retVal -} - -func densesToDenseTensors(a []*Dense) []DenseTensor { - retVal := make([]DenseTensor, len(a)) - for i, t := range a { - retVal[i] = t - } - return retVal -} - -func tensorsToDenseTensors(a []Tensor) ([]DenseTensor, error) { - retVal := make([]DenseTensor, len(a)) - var ok bool - for i, t := range a { - if retVal[i], ok = t.(DenseTensor); !ok { - return nil, errors.Errorf("can only convert Tensors of the same type to DenseTensors. 
Trying to convert %T (#%d in slice)", t, i) - } - } - return retVal, nil -} +package tensor + +import "github.com/pkg/errors" + +func densesToTensors(a []*Dense) []Tensor { + retVal := make([]Tensor, len(a)) + for i, t := range a { + retVal[i] = t + } + return retVal +} + +func densesToDenseTensors(a []*Dense) []DenseTensor { + retVal := make([]DenseTensor, len(a)) + for i, t := range a { + retVal[i] = t + } + return retVal +} + +func tensorsToDenseTensors(a []Tensor) ([]DenseTensor, error) { + retVal := make([]DenseTensor, len(a)) + var ok bool + for i, t := range a { + if retVal[i], ok = t.(DenseTensor); !ok { + return nil, errors.Errorf("can only convert Tensors of the same type to DenseTensors. Trying to convert %T (#%d in slice)", t, i) + } + } + return retVal, nil +} diff --git a/consopt.go b/consopt.go index 9118cad..4b2de84 100644 --- a/consopt.go +++ b/consopt.go @@ -2,14 +2,17 @@ package tensor import ( "reflect" - "unsafe" + + "gorgonia.org/dtype" + "gorgonia.org/tensor/internal/storage" ) // ConsOpt is a tensor construction option. type ConsOpt func(Tensor) // Of is a construction option for a Tensor. -func Of(a Dtype) ConsOpt { +func Of(a dtype.Dtype) ConsOpt { + dtype.Register(a) f := func(t Tensor) { switch tt := t.(type) { case *Dense: @@ -53,7 +56,7 @@ func WithBacking(x interface{}, argMask ...[]bool) ConsOpt { // WithMask is a construction option for a Tensor // Use it as such: // mask := []bool{true,true,false,false} -// t := New(WithBacking(backing)) +// t := New(WithBacking(backing), WithMask(mask)) // It can be used with other construction options like WithShape // The supplied mask can be any type. If non-boolean, then tensor mask is set to true // wherever non-zero value is obtained @@ -105,18 +108,13 @@ func FromScalar(x interface{}, argMask ...[]bool) ConsOpt { f := func(t Tensor) { switch tt := t.(type) { case *Dense: - xt := reflect.TypeOf(x) - xv := reflect.New(xt) - xvi := reflect.Indirect(xv) - xvi.Set(reflect.ValueOf(x)) - ptr := xv.Pointer() - uptr := unsafe.Pointer(ptr) - - tt.array.Ptr = uptr - tt.array.L = 1 - tt.array.C = 1 - tt.v = x - tt.t = Dtype{xt} + xT := reflect.TypeOf(x) + sxT := reflect.SliceOf(xT) + xv := reflect.MakeSlice(sxT, 1, 1) // []T + xv0 := xv.Index(0) // xv[0] + xv0.Set(reflect.ValueOf(x)) + tt.array.Header.Raw = storage.AsByteSlice(xv.Interface()) + tt.t = dtype.Dtype{xT} tt.mask = mask default: @@ -140,22 +138,14 @@ func FromScalar(x interface{}, argMask ...[]bool) ConsOpt { // Memory must be manually managed by the caller. // Tensors called with this construction option will not be returned to any pool - rather, all references to the pointers will be null'd. // Use with caution. 
+//go:nocheckptr func FromMemory(ptr uintptr, memsize uintptr) ConsOpt { f := func(t Tensor) { switch tt := t.(type) { case *Dense: - tt.v = nil // if there were any underlying slices it should be GC'd - - tt.array.Ptr = unsafe.Pointer(ptr) - tt.array.L = int(memsize / tt.t.Size()) - tt.array.C = int(memsize / tt.t.Size()) - + tt.Header.Raw = nil // GC anything if needed + tt.Header.Raw = storage.FromMemory(ptr, memsize) tt.flag = MakeMemoryFlag(tt.flag, ManuallyManaged) - - if tt.IsNativelyAccessible() { - tt.array.fix() - } - default: panic("Unsupported Tensor type") } @@ -172,9 +162,11 @@ func WithEngine(e Engine) ConsOpt { if e != nil && !e.AllocAccessible() { tt.flag = MakeMemoryFlag(tt.flag, NativelyInaccessible) } - // if oe, ok := e.(standardEngine); ok { - // tt.oe = oe - // } + + tt.oe = nil + if oe, ok := e.(StandardEngine); ok { + tt.oe = oe + } case *CS: tt.e = e if e != nil && !e.AllocAccessible() { @@ -185,14 +177,81 @@ func WithEngine(e Engine) ConsOpt { return f } -func AsFortran() ConsOpt { +// AsFortran creates a *Dense with a col-major layout. +// If the optional backing argument is passed, the backing is assumed to be C-order (row major), and +// it will be transposed before being used. +func AsFortran(backing interface{}, argMask ...[]bool) ConsOpt { + var mask []bool + if len(argMask) > 0 { + mask = argMask[0] + } f := func(t Tensor) { switch tt := t.(type) { case *Dense: - if tt.AP == nil { - // create AP + if backing != nil { + // put the data into the tensor, then make a clone tensor to transpose + tt.fromSliceOrArrayer(backing) + // create a temporary tensor, to which the transpose will be done + tmp := NewDense(tt.Dtype(), tt.shape.Clone()) + copyArray(tmp.arrPtr(), tt.arrPtr()) + tmp.SetMask(mask) + tmp.T() + tmp.Transpose() + // copy the data back to the current tensor + copyArray(tt.arrPtr(), tmp.arrPtr()) + tt.SetMask(tmp.Mask()) + // cleanup: return the temporary tensor back to the pool + ReturnTensor(tmp) } + tt.AP.o = MakeDataOrder(tt.AP.o, ColMajor) + if tt.AP.shape != nil { + ReturnInts(tt.AP.strides) + tt.AP.strides = nil + tt.AP.strides = tt.AP.calcStrides() + } + case *CS: + panic("AsFortran is not an available option for Compressed Sparse layouts") + } + } + return f +} + +func AsDenseDiag(backing interface{}) ConsOpt { + f := func(t Tensor) { + switch tt := t.(type) { + case *Dense: + if bt, ok := backing.(Tensor); ok { + backing = bt.Data() + } + xT := reflect.TypeOf(backing) + if xT.Kind() != reflect.Slice { + panic("Expected a slice") + } + xV := reflect.ValueOf(backing) + l := xV.Len() + // elT := xT.Elem() + + sli := reflect.MakeSlice(xT, l*l, l*l) + + shape := Shape{l, l} + strides := CalcStrides(shape) + for i := 0; i < l; i++ { + idx, err := Ltoi(shape, strides, i, i) + if err != nil { + panic(err) + } + + at := sli.Index(idx) + xi := xV.Index(i) + at.Set(xi) + } + + tt.fromSliceOrArrayer(sli.Interface()) + tt.setShape(l, l) + + default: + panic("AsDenseDiag is not available as an option for CS") } } return f diff --git a/consopt_test.go b/consopt_test.go new file mode 100644 index 0000000..67ad664 --- /dev/null +++ b/consopt_test.go @@ -0,0 +1,96 @@ +// +build linux + +package tensor + +import ( + "fmt" + "io/ioutil" + "os" + "syscall" + "testing" + "testing/quick" + "unsafe" + + "github.com/stretchr/testify/assert" +) + +type F64 float64 + +func newF64(f float64) *F64 { r := F64(f); return &r } + +func (f *F64) Uintptr() uintptr { return uintptr(unsafe.Pointer(f)) } + +func (f *F64) MemSize() uintptr { return 8 } + +func (f *F64) 
Pointer() unsafe.Pointer { return unsafe.Pointer(f) } + +func Test_FromMemory(t *testing.T) { + fn := func(F float64) bool { + f := newF64(F) + T := New(WithShape(), Of(Float64), FromMemory(f.Uintptr(), f.MemSize())) + data := T.Data().(float64) + + if data != F { + return false + } + return true + } + if err := quick.Check(fn, &quick.Config{MaxCount: 1000000}); err != nil { + t.Logf("%v", err) + } + + f, err := ioutil.TempFile("", "test") + if err != nil { + t.Fatal(err) + } + // fill in with fake data + backing := make([]byte, 8*1024*1024) // 1024*1024 matrix of float64 + asFloats := *(*[]float64)(unsafe.Pointer(&backing)) + asFloats = asFloats[: 1024*1024 : 1024*1024] + asFloats[0] = 3.14 + asFloats[2] = 6.28 + asFloats[1024*1024-1] = 3.14 + asFloats[1024*1024-3] = 6.28 + f.Write(backing) + + // defer cleanup + defer os.Remove(f.Name()) + + // do the mmap stuff + stat, err := f.Stat() + if err != nil { + t.Fatal(err) + } + + size := int(stat.Size()) + fd := int(f.Fd()) + bs, err := syscall.Mmap(fd, 0, size, syscall.PROT_READ, syscall.MAP_SHARED) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := syscall.Munmap(bs); err != nil { + t.Error(err) + } + }() + T := New(WithShape(1024, 1024), Of(Float64), FromMemory(uintptr(unsafe.Pointer(&bs[0])), uintptr(size))) + + s := fmt.Sprintf("%v", T) + expected := `⎡3.14 0 6.28 0 ... 0 0 0 0⎤ +⎢ 0 0 0 0 ... 0 0 0 0⎥ +⎢ 0 0 0 0 ... 0 0 0 0⎥ +⎢ 0 0 0 0 ... 0 0 0 0⎥ +. +. +. +⎢ 0 0 0 0 ... 0 0 0 0⎥ +⎢ 0 0 0 0 ... 0 0 0 0⎥ +⎢ 0 0 0 0 ... 0 0 0 0⎥ +⎣ 0 0 0 0 ... 0 6.28 0 3.14⎦ +` + if s != expected { + t.Errorf("Expected mmap'd tensor to be exactly the same.") + } + + assert.True(t, T.IsManuallyManaged()) +} diff --git a/defaultengine.go b/defaultengine.go index 6dd1f45..5338391 100644 --- a/defaultengine.go +++ b/defaultengine.go @@ -1,81 +1,76 @@ -package tensor - -import ( - "unsafe" - - "github.com/pkg/errors" - "gorgonia.org/tensor/internal/execution" -) - -// StdEng is the default execution engine that comes with the tensors. To use other execution engines, use the WithEngine construction option. 
-type StdEng struct { - execution.E -} - -// makeArray allocates a slice for the array -func (e StdEng) makeArray(arr *array, t Dtype, size int) { - memsize := calcMemSize(t, size) - s := make([]byte, memsize) - arr.t = t - arr.L = size - arr.C = size - arr.Ptr = unsafe.Pointer(&s[0]) - arr.fix() -} - -func (e StdEng) AllocAccessible() bool { return true } -func (e StdEng) Alloc(size int64) (Memory, error) { return nil, noopError{} } -func (e StdEng) Free(mem Memory, size int64) error { return nil } -func (e StdEng) Memset(mem Memory, val interface{}) error { - if ms, ok := mem.(MemSetter); ok { - return ms.Memset(val) - } - return errors.Errorf("Cannot memset %v with StdEng") -} - -func (e StdEng) Memclr(mem Memory) { - if z, ok := mem.(Zeroer); ok { - z.Zero() - } - return -} - -func (e StdEng) Memcpy(dst, src Memory) error { - switch dt := dst.(type) { - case *array: - switch st := src.(type) { - case *array: - copyArray(dt, st) - return nil - case arrayer: - copyArray(dt, st.arrPtr()) - return nil - } - case arrayer: - switch st := src.(type) { - case *array: - copyArray(dt.arrPtr(), st) - return nil - case arrayer: - copyArray(dt.arrPtr(), st.arrPtr()) - return nil - } - } - return errors.Errorf("Failed to copy %T %T", dst, src) -} - -func (e StdEng) Accessible(mem Memory) (Memory, error) { return mem, nil } - -func (e StdEng) WorksWith(order DataOrder) bool { - if order.isColMajor() { - return false - } - return true -} - -func (e StdEng) checkAccessible(t Tensor) error { - if !t.IsNativelyAccessible() { - return errors.Errorf(inaccessibleData, t) - } - return nil -} +package tensor + +import ( + "github.com/pkg/errors" + "gorgonia.org/dtype" + "gorgonia.org/tensor/internal/execution" +) + +// stdDenseEng is the default execution engine for dense tensor operations. +type stdDenseEng struct { + execution.E +} + +// StdEng is the default execution engine that comes with the tensors. To use other execution engines, use the WithEngine construction option. 
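As the doc comment above notes, other execution engines are selected with the WithEngine construction option; a minimal usage sketch (here the default engine is passed explicitly, but any type satisfying Engine could be used instead):

	T := New(WithShape(2, 3), Of(Float64), WithEngine(StdEng{}))
	_ = T // T's operations are now dispatched through StdEng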
+type StdEng struct { + stdDenseEng +} + +// makeArray allocates a slice for the array +func (e StdEng) makeArray(arr *array, t dtype.Dtype, size int) { + arr.Raw = malloc(t, size) + arr.t = t +} + +func (e StdEng) AllocAccessible() bool { return true } +func (e StdEng) Alloc(size int64) (Memory, error) { return nil, noopError{} } + +func (e StdEng) Free(mem Memory, size int64) error { return nil } +func (e StdEng) Memset(mem Memory, val interface{}) error { + if ms, ok := mem.(MemSetter); ok { + return ms.Memset(val) + } + return errors.Errorf("Cannot memset %v with StdEng", mem) +} + +func (e StdEng) Memclr(mem Memory) { + if z, ok := mem.(Zeroer); ok { + z.Zero() + } + return +} + +func (e StdEng) Memcpy(dst, src Memory) error { + switch dt := dst.(type) { + case *array: + switch st := src.(type) { + case *array: + copyArray(dt, st) + return nil + case arrayer: + copyArray(dt, st.arrPtr()) + return nil + } + case arrayer: + switch st := src.(type) { + case *array: + copyArray(dt.arrPtr(), st) + return nil + case arrayer: + copyArray(dt.arrPtr(), st.arrPtr()) + return nil + } + } + return errors.Errorf("Failed to copy %T %T", dst, src) +} + +func (e StdEng) Accessible(mem Memory) (Memory, error) { return mem, nil } + +func (e StdEng) WorksWith(order DataOrder) bool { return true } + +func (e StdEng) checkAccessible(t Tensor) error { + if !t.IsNativelyAccessible() { + return errors.Errorf(inaccessibleData, t) + } + return nil +} diff --git a/defaultengine_argmethods.go b/defaultengine_argmethods.go index 3cedd84..0bb1707 100644 --- a/defaultengine_argmethods.go +++ b/defaultengine_argmethods.go @@ -1,19 +1,27 @@ package tensor -import "github.com/pkg/errors" +import ( + "context" -func (e StdEng) Argmax(t Tensor, axis int) (retVal Tensor, err error) { + "github.com/pkg/errors" + "gorgonia.org/dtype" +) + +func (e StdEng) Argmax(ctx context.Context, t Tensor, axis int) (retVal Tensor, err error) { switch tt := t.(type) { case DenseTensor: - return e.argmaxDenseTensor(tt, axis) + return e.argmaxDenseTensor(ctx, tt, axis) default: - return nil, errors.Errorf(typeNYI, "StdEng.Argmax", t) + return nil, nyierr(typeNYI, t) } } -func (e StdEng) argmaxDenseTensor(t DenseTensor, axis int) (retVal *Dense, err error) { - if err = unaryCheck(t, ordTypes); err != nil { +func (e StdEng) argmaxDenseTensor(ctx context.Context, t DenseTensor, axis int) (retVal *Dense, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } + if err = unaryCheck(t, dtype.Ord); err != nil { return nil, errors.Wrapf(err, opFail, "Argmax") } @@ -59,17 +67,21 @@ func (e StdEng) argmaxDenseTensor(t DenseTensor, axis int) (retVal *Dense, err e if _, ok := err.(NoOpError); !ok && err != nil { return } else if ok { - newAP = t.Info().Clone() + t.Info().CloneTo(&newAP) } - defer ReturnAP(newAP) it := IteratorFromDense(t) - iteratorLoadAP(it, newAP) + iteratorLoadAP(it, &newAP) lastSize := it.Shape()[len(it.Shape())-1] newShape := it.Shape().Clone() newShape = newShape[:len(newShape)-1] - defer ReturnInts(newShape) + + // cleanup + defer func() { + newAP.zero() + ReturnInts(newShape) + }() if mt, ok := t.(MaskedTensor); ok && mt.IsMasked() { mask := mt.Mask() @@ -85,18 +97,21 @@ func (e StdEng) argmaxDenseTensor(t DenseTensor, axis int) (retVal *Dense, err e return New(WithShape(newShape...), WithBacking(indices)), nil } -func (e StdEng) Argmin(t Tensor, axis int) (retVal Tensor, err error) { +func (e StdEng) Argmin(ctx context.Context, t Tensor, axis int) (retVal Tensor, err error) { switch tt := t.(type) { case 
DenseTensor: - return e.argminDenseTensor(tt, axis) + return e.argminDenseTensor(ctx, tt, axis) default: - return nil, errors.Errorf(typeNYI, "StdEng.Argmin", t) + return nil, nyierr(typeNYI, t) } } -func (e StdEng) argminDenseTensor(t DenseTensor, axis int) (retVal *Dense, err error) { - if err = unaryCheck(t, ordTypes); err != nil { +func (e StdEng) argminDenseTensor(ctx context.Context, t DenseTensor, axis int) (retVal *Dense, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } + if err = unaryCheck(t, dtype.Ord); err != nil { return nil, errors.Wrapf(err, opFail, "Argmin") } @@ -144,15 +159,19 @@ func (e StdEng) argminDenseTensor(t DenseTensor, axis int) (retVal *Dense, err e } else if ok { newAP = t.Info().Clone() } - defer ReturnAP(newAP) it := IteratorFromDense(t) - iteratorLoadAP(it, newAP) + iteratorLoadAP(it, &newAP) lastSize := it.Shape()[len(it.Shape())-1] newShape := it.Shape().Clone() newShape = newShape[:len(newShape)-1] - defer ReturnInts(newShape) + + // cleanup + defer func() { + newAP.zero() + ReturnInts(newShape) + }() if mt, ok := t.(MaskedTensor); ok && mt.IsMasked() { mask := mt.Mask() diff --git a/defaultengine_arith.go b/defaultengine_arith.go index 01d9784..369475d 100644 --- a/defaultengine_arith.go +++ b/defaultengine_arith.go @@ -1,24 +1,32 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( + "context" + "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + // Add performs a + b elementwise. Both a and b must have the same shape. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, numberTypes); err != nil { + if err = binaryCheck(a, b, dtype.Number); err != nil { + return nil, errors.Wrapf(err, "Add failed") } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -48,7 +56,6 @@ func (e StdEng) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err } err = e.E.AddIter(typ, retVal.hdr(), dataB, ait, bit) } - return } switch { @@ -56,8 +63,7 @@ func (e StdEng) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err err = e.E.AddIncr(typ, dataA, dataB, dataReuse) retVal = reuse case toReuse: - storage.Copy(typ, dataReuse, dataA) - err = e.E.Add(typ, dataReuse, dataB) + err = e.E.AddRecv(typ, dataA, dataB, dataReuse) retVal = reuse case !safe: err = e.E.Add(typ, dataA, dataB) @@ -70,22 +76,26 @@ func (e StdEng) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err } err = e.E.Add(typ, retVal.hdr(), dataB) } - return } // Sub performs a - b elementwise. Both a and b must have the same shape. 
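// For illustration only (not part of this patch): how the FuncOpts listed in the Add
// comment above surface through the package-level API. WithReuse, WithIncr and
// UseUnsafe are named in the doc comment; the context consumed by handleCtx is assumed
// here to come from a WithContext FuncOpt, and the package-level functions accepting a
// plain Go scalar operand is likewise an assumption.
package main

import (
    "context"
    "fmt"

    "gorgonia.org/tensor"
)

func main() {
    a := tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4}))
    b := tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{10, 20, 30, 40}))
    reuse := tensor.New(tensor.WithShape(2, 2), tensor.Of(tensor.Float64))

    sum, _ := tensor.Add(a, b)                       // safe: allocates the result
    _, _ = tensor.Add(a, b, tensor.WithReuse(reuse)) // writes the result into reuse
    _, _ = tensor.Add(a, b, tensor.WithIncr(reuse))  // accumulates: reuse += a + b
    _, _ = tensor.Add(a, b, tensor.UseUnsafe())      // overwrites a with the result

    // Tensor-scalar forms dispatch to the *Scalar engine methods; the leftTensor
    // parameter distinguishes a - 1 from 1 - a.
    _, _ = tensor.Sub(a, 1.0)
    _, _ = tensor.Sub(1.0, a)

    // A cancelled context makes handleCtx report a NoOpError instead of doing work.
    ctx, cancel := context.WithCancel(context.Background())
    cancel()
    _, err := tensor.Add(a, b, tensor.WithContext(ctx))
    fmt.Println(sum, err)
}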
// Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) Sub(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, numberTypes); err != nil { + if err = binaryCheck(a, b, dtype.Number); err != nil { + return nil, errors.Wrapf(err, "Sub failed") } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -115,7 +125,6 @@ func (e StdEng) Sub(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err } err = e.E.SubIter(typ, retVal.hdr(), dataB, ait, bit) } - return } switch { @@ -123,8 +132,7 @@ func (e StdEng) Sub(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err err = e.E.SubIncr(typ, dataA, dataB, dataReuse) retVal = reuse case toReuse: - storage.Copy(typ, dataReuse, dataA) - err = e.E.Sub(typ, dataReuse, dataB) + err = e.E.SubRecv(typ, dataA, dataB, dataReuse) retVal = reuse case !safe: err = e.E.Sub(typ, dataA, dataB) @@ -137,22 +145,26 @@ func (e StdEng) Sub(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err } err = e.E.Sub(typ, retVal.hdr(), dataB) } - return } // Mul performs a × b elementwise. Both a and b must have the same shape. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) Mul(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, numberTypes); err != nil { + if err = binaryCheck(a, b, dtype.Number); err != nil { + return nil, errors.Wrapf(err, "Mul failed") } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -182,7 +194,6 @@ func (e StdEng) Mul(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err } err = e.E.MulIter(typ, retVal.hdr(), dataB, ait, bit) } - return } switch { @@ -190,8 +201,7 @@ func (e StdEng) Mul(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err err = e.E.MulIncr(typ, dataA, dataB, dataReuse) retVal = reuse case toReuse: - storage.Copy(typ, dataReuse, dataA) - err = e.E.Mul(typ, dataReuse, dataB) + err = e.E.MulRecv(typ, dataA, dataB, dataReuse) retVal = reuse case !safe: err = e.E.Mul(typ, dataA, dataB) @@ -204,22 +214,26 @@ func (e StdEng) Mul(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err } err = e.E.Mul(typ, retVal.hdr(), dataB) } - return } // Div performs a ÷ b elementwise. Both a and b must have the same shape. 
// Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) Div(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, numberTypes); err != nil { + if err = binaryCheck(a, b, dtype.Number); err != nil { + return nil, errors.Wrapf(err, "Div failed") } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -249,7 +263,6 @@ func (e StdEng) Div(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err } err = e.E.DivIter(typ, retVal.hdr(), dataB, ait, bit) } - return } switch { @@ -257,8 +270,7 @@ func (e StdEng) Div(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err err = e.E.DivIncr(typ, dataA, dataB, dataReuse) retVal = reuse case toReuse: - storage.Copy(typ, dataReuse, dataA) - err = e.E.Div(typ, dataReuse, dataB) + err = e.E.DivRecv(typ, dataA, dataB, dataReuse) retVal = reuse case !safe: err = e.E.Div(typ, dataA, dataB) @@ -271,22 +283,26 @@ func (e StdEng) Div(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err } err = e.E.Div(typ, retVal.hdr(), dataB) } - return } // Pow performs a ^ b elementwise. Both a and b must have the same shape. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) Pow(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, numberTypes); err != nil { + if err = binaryCheck(a, b, dtype.Number); err != nil { + return nil, errors.Wrapf(err, "Pow failed") } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -316,7 +332,6 @@ func (e StdEng) Pow(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err } err = e.E.PowIter(typ, retVal.hdr(), dataB, ait, bit) } - return } switch { @@ -324,8 +339,7 @@ func (e StdEng) Pow(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err err = e.E.PowIncr(typ, dataA, dataB, dataReuse) retVal = reuse case toReuse: - storage.Copy(typ, dataReuse, dataA) - err = e.E.Pow(typ, dataReuse, dataB) + err = e.E.PowRecv(typ, dataA, dataB, dataReuse) retVal = reuse case !safe: err = e.E.Pow(typ, dataA, dataB) @@ -338,22 +352,26 @@ func (e StdEng) Pow(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err } err = e.E.Pow(typ, retVal.hdr(), dataB) } - return } // Mod performs a % b elementwise. Both a and b must have the same shape. 
// Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) Mod(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, numberTypes); err != nil { + if err = binaryCheck(a, b, dtype.Number); err != nil { + return nil, errors.Wrapf(err, "Mod failed") } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -383,7 +401,6 @@ func (e StdEng) Mod(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err } err = e.E.ModIter(typ, retVal.hdr(), dataB, ait, bit) } - return } switch { @@ -391,8 +408,7 @@ func (e StdEng) Mod(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err err = e.E.ModIncr(typ, dataA, dataB, dataReuse) retVal = reuse case toReuse: - storage.Copy(typ, dataReuse, dataA) - err = e.E.Mod(typ, dataReuse, dataB) + err = e.E.ModRecv(typ, dataA, dataB, dataReuse) retVal = reuse case !safe: err = e.E.Mod(typ, dataA, dataB) @@ -405,35 +421,42 @@ func (e StdEng) Mod(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err } err = e.E.Mod(typ, retVal.hdr(), dataB) } - return } // AddScalar performs t + s elementwise. The leftTensor parameter indicates if the tensor is the left operand. Only scalar types are accepted in s. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) AddScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, numberTypes); err != nil { + if err = unaryCheck(t, dtype.Number); err != nil { return nil, errors.Wrapf(err, "Add failed") } + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "Add failed") + } + var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } a := t typ := t.Dtype().Type var ait, bit, iit Iterator var dataA, dataB, dataReuse, scalarHeader *storage.Header - var useIter bool + var useIter, newAlloc bool if leftTensor { - if dataA, dataB, dataReuse, ait, iit, useIter, err = prepDataVS(t, s, reuse); err != nil { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Add") } scalarHeader = dataB } else { - if dataA, dataB, dataReuse, bit, iit, useIter, err = prepDataSV(s, t, reuse); err != nil { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Add") } scalarHeader = dataA @@ -467,6 +490,9 @@ func (e StdEng) AddScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func err = e.E.AddIter(typ, dataA, retVal.hdr(), ait, bit) } } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } @@ -481,17 +507,25 @@ func (e StdEng) AddScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func case toReuse && !leftTensor: storage.Copy(typ, dataReuse, dataB) err = e.E.Add(typ, dataA, dataReuse) + if t.Shape().IsScalarEquiv() { + storage.Copy(typ, dataReuse, dataA) + } retVal = reuse case !safe: err = e.E.Add(typ, dataA, dataB) + if t.Shape().IsScalarEquiv() && !leftTensor { + storage.Copy(typ, dataB, dataA) + } retVal = a default: retVal = a.Clone().(Tensor) - if leftTensor { - err = e.E.Add(typ, retVal.hdr(), dataB) - } else { - err = e.E.Add(typ, dataA, retVal.hdr()) + if !leftTensor { + storage.Fill(typ, retVal.hdr(), dataA) } + err = e.E.Add(typ, retVal.hdr(), dataB) + } + if newAlloc { + freeScalar(scalarHeader.Raw) } returnHeader(scalarHeader) return @@ -500,28 +534,36 @@ func (e StdEng) AddScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func // SubScalar performs t - s elementwise. The leftTensor parameter indicates if the tensor is the left operand. Only scalar types are accepted in s. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) SubScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, numberTypes); err != nil { + if err = unaryCheck(t, dtype.Number); err != nil { return nil, errors.Wrapf(err, "Sub failed") } + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "Sub failed") + } + var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } a := t typ := t.Dtype().Type var ait, bit, iit Iterator var dataA, dataB, dataReuse, scalarHeader *storage.Header - var useIter bool + var useIter, newAlloc bool if leftTensor { - if dataA, dataB, dataReuse, ait, iit, useIter, err = prepDataVS(t, s, reuse); err != nil { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Sub") } scalarHeader = dataB } else { - if dataA, dataB, dataReuse, bit, iit, useIter, err = prepDataSV(s, t, reuse); err != nil { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Sub") } scalarHeader = dataA @@ -555,6 +597,9 @@ func (e StdEng) SubScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func err = e.E.SubIter(typ, dataA, retVal.hdr(), ait, bit) } } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } @@ -569,17 +614,25 @@ func (e StdEng) SubScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func case toReuse && !leftTensor: storage.Copy(typ, dataReuse, dataB) err = e.E.Sub(typ, dataA, dataReuse) + if t.Shape().IsScalarEquiv() { + storage.Copy(typ, dataReuse, dataA) + } retVal = reuse case !safe: err = e.E.Sub(typ, dataA, dataB) + if t.Shape().IsScalarEquiv() && !leftTensor { + storage.Copy(typ, dataB, dataA) + } retVal = a default: retVal = a.Clone().(Tensor) - if leftTensor { - err = e.E.Sub(typ, retVal.hdr(), dataB) - } else { - err = e.E.Sub(typ, dataA, retVal.hdr()) + if !leftTensor { + storage.Fill(typ, retVal.hdr(), dataA) } + err = e.E.Sub(typ, retVal.hdr(), dataB) + } + if newAlloc { + freeScalar(scalarHeader.Raw) } returnHeader(scalarHeader) return @@ -588,28 +641,36 @@ func (e StdEng) SubScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func // MulScalar performs t × s elementwise. The leftTensor parameter indicates if the tensor is the left operand. Only scalar types are accepted in s. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) MulScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, numberTypes); err != nil { + if err = unaryCheck(t, dtype.Number); err != nil { return nil, errors.Wrapf(err, "Mul failed") } + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "Mul failed") + } + var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } a := t typ := t.Dtype().Type var ait, bit, iit Iterator var dataA, dataB, dataReuse, scalarHeader *storage.Header - var useIter bool + var useIter, newAlloc bool if leftTensor { - if dataA, dataB, dataReuse, ait, iit, useIter, err = prepDataVS(t, s, reuse); err != nil { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Mul") } scalarHeader = dataB } else { - if dataA, dataB, dataReuse, bit, iit, useIter, err = prepDataSV(s, t, reuse); err != nil { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Mul") } scalarHeader = dataA @@ -643,6 +704,9 @@ func (e StdEng) MulScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func err = e.E.MulIter(typ, dataA, retVal.hdr(), ait, bit) } } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } @@ -657,17 +721,25 @@ func (e StdEng) MulScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func case toReuse && !leftTensor: storage.Copy(typ, dataReuse, dataB) err = e.E.Mul(typ, dataA, dataReuse) + if t.Shape().IsScalarEquiv() { + storage.Copy(typ, dataReuse, dataA) + } retVal = reuse case !safe: err = e.E.Mul(typ, dataA, dataB) + if t.Shape().IsScalarEquiv() && !leftTensor { + storage.Copy(typ, dataB, dataA) + } retVal = a default: retVal = a.Clone().(Tensor) - if leftTensor { - err = e.E.Mul(typ, retVal.hdr(), dataB) - } else { - err = e.E.Mul(typ, dataA, retVal.hdr()) + if !leftTensor { + storage.Fill(typ, retVal.hdr(), dataA) } + err = e.E.Mul(typ, retVal.hdr(), dataB) + } + if newAlloc { + freeScalar(scalarHeader.Raw) } returnHeader(scalarHeader) return @@ -676,28 +748,36 @@ func (e StdEng) MulScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func // DivScalar performs t ÷ s elementwise. The leftTensor parameter indicates if the tensor is the left operand. Only scalar types are accepted in s. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) DivScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, numberTypes); err != nil { + if err = unaryCheck(t, dtype.Number); err != nil { return nil, errors.Wrapf(err, "Div failed") } + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "Div failed") + } + var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } a := t typ := t.Dtype().Type var ait, bit, iit Iterator var dataA, dataB, dataReuse, scalarHeader *storage.Header - var useIter bool + var useIter, newAlloc bool if leftTensor { - if dataA, dataB, dataReuse, ait, iit, useIter, err = prepDataVS(t, s, reuse); err != nil { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Div") } scalarHeader = dataB } else { - if dataA, dataB, dataReuse, bit, iit, useIter, err = prepDataSV(s, t, reuse); err != nil { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Div") } scalarHeader = dataA @@ -731,6 +811,9 @@ func (e StdEng) DivScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func err = e.E.DivIter(typ, dataA, retVal.hdr(), ait, bit) } } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } @@ -745,17 +828,25 @@ func (e StdEng) DivScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func case toReuse && !leftTensor: storage.Copy(typ, dataReuse, dataB) err = e.E.Div(typ, dataA, dataReuse) + if t.Shape().IsScalarEquiv() { + storage.Copy(typ, dataReuse, dataA) + } retVal = reuse case !safe: err = e.E.Div(typ, dataA, dataB) + if t.Shape().IsScalarEquiv() && !leftTensor { + storage.Copy(typ, dataB, dataA) + } retVal = a default: retVal = a.Clone().(Tensor) - if leftTensor { - err = e.E.Div(typ, retVal.hdr(), dataB) - } else { - err = e.E.Div(typ, dataA, retVal.hdr()) + if !leftTensor { + storage.Fill(typ, retVal.hdr(), dataA) } + err = e.E.Div(typ, retVal.hdr(), dataB) + } + if newAlloc { + freeScalar(scalarHeader.Raw) } returnHeader(scalarHeader) return @@ -764,28 +855,36 @@ func (e StdEng) DivScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func // PowScalar performs t ^ s elementwise. The leftTensor parameter indicates if the tensor is the left operand. Only scalar types are accepted in s. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) PowScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, numberTypes); err != nil { + if err = unaryCheck(t, dtype.Number); err != nil { return nil, errors.Wrapf(err, "Pow failed") } + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "Pow failed") + } + var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } a := t typ := t.Dtype().Type var ait, bit, iit Iterator var dataA, dataB, dataReuse, scalarHeader *storage.Header - var useIter bool + var useIter, newAlloc bool if leftTensor { - if dataA, dataB, dataReuse, ait, iit, useIter, err = prepDataVS(t, s, reuse); err != nil { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Pow") } scalarHeader = dataB } else { - if dataA, dataB, dataReuse, bit, iit, useIter, err = prepDataSV(s, t, reuse); err != nil { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Pow") } scalarHeader = dataA @@ -819,6 +918,9 @@ func (e StdEng) PowScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func err = e.E.PowIter(typ, dataA, retVal.hdr(), ait, bit) } } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } @@ -833,17 +935,25 @@ func (e StdEng) PowScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func case toReuse && !leftTensor: storage.Copy(typ, dataReuse, dataB) err = e.E.Pow(typ, dataA, dataReuse) + if t.Shape().IsScalarEquiv() { + storage.Copy(typ, dataReuse, dataA) + } retVal = reuse case !safe: err = e.E.Pow(typ, dataA, dataB) + if t.Shape().IsScalarEquiv() && !leftTensor { + storage.Copy(typ, dataB, dataA) + } retVal = a default: retVal = a.Clone().(Tensor) - if leftTensor { - err = e.E.Pow(typ, retVal.hdr(), dataB) - } else { - err = e.E.Pow(typ, dataA, retVal.hdr()) + if !leftTensor { + storage.Fill(typ, retVal.hdr(), dataA) } + err = e.E.Pow(typ, retVal.hdr(), dataB) + } + if newAlloc { + freeScalar(scalarHeader.Raw) } returnHeader(scalarHeader) return @@ -852,28 +962,36 @@ func (e StdEng) PowScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func // ModScalar performs t % s elementwise. The leftTensor parameter indicates if the tensor is the left operand. Only scalar types are accepted in s. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) ModScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, numberTypes); err != nil { + if err = unaryCheck(t, dtype.Number); err != nil { return nil, errors.Wrapf(err, "Mod failed") } + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "Mod failed") + } + var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } a := t typ := t.Dtype().Type var ait, bit, iit Iterator var dataA, dataB, dataReuse, scalarHeader *storage.Header - var useIter bool + var useIter, newAlloc bool if leftTensor { - if dataA, dataB, dataReuse, ait, iit, useIter, err = prepDataVS(t, s, reuse); err != nil { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Mod") } scalarHeader = dataB } else { - if dataA, dataB, dataReuse, bit, iit, useIter, err = prepDataSV(s, t, reuse); err != nil { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Mod") } scalarHeader = dataA @@ -907,6 +1025,9 @@ func (e StdEng) ModScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func err = e.E.ModIter(typ, dataA, retVal.hdr(), ait, bit) } } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } @@ -921,17 +1042,25 @@ func (e StdEng) ModScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func case toReuse && !leftTensor: storage.Copy(typ, dataReuse, dataB) err = e.E.Mod(typ, dataA, dataReuse) + if t.Shape().IsScalarEquiv() { + storage.Copy(typ, dataReuse, dataA) + } retVal = reuse case !safe: err = e.E.Mod(typ, dataA, dataB) + if t.Shape().IsScalarEquiv() && !leftTensor { + storage.Copy(typ, dataB, dataA) + } retVal = a default: retVal = a.Clone().(Tensor) - if leftTensor { - err = e.E.Mod(typ, retVal.hdr(), dataB) - } else { - err = e.E.Mod(typ, dataA, retVal.hdr()) + if !leftTensor { + storage.Fill(typ, retVal.hdr(), dataA) } + err = e.E.Mod(typ, retVal.hdr(), dataB) + } + if newAlloc { + freeScalar(scalarHeader.Raw) } returnHeader(scalarHeader) return diff --git a/defaultengine_cmp.go b/defaultengine_cmp.go index 98f61e1..6a986d3 100644 --- a/defaultengine_cmp.go +++ b/defaultengine_cmp.go @@ -1,29 +1,37 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( + "context" + "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + // Gt performs a > b elementwise. Both a and b must have the same shape. // Acceptable FuncOpts are: UseUnsafe(), AsSameType(), WithReuse(). //UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) Gt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, ordTypes); err != nil { + if err = binaryCheck(a, b, dtype.Ord); err != nil { + return nil, errors.Wrapf(err, "Gt failed") } var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -66,7 +74,6 @@ func (e StdEng) Gt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err erro err = e.E.GtIter(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse } - return } @@ -83,7 +90,6 @@ func (e StdEng) Gt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err erro err = e.E.Gt(typ, dataA, dataB, dataReuse) retVal = reuse } - return } @@ -92,18 +98,23 @@ func (e StdEng) Gt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err erro //UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) Gte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, ordTypes); err != nil { + if err = binaryCheck(a, b, dtype.Ord); err != nil { + return nil, errors.Wrapf(err, "Gte failed") } var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -146,7 +157,6 @@ func (e StdEng) Gte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err err = e.E.GteIter(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse } - return } @@ -163,7 +173,6 @@ func (e StdEng) Gte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err err = e.E.Gte(typ, dataA, dataB, dataReuse) retVal = reuse } - return } @@ -172,18 +181,23 @@ func (e StdEng) Gte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err //UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) Lt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, ordTypes); err != nil { + if err = binaryCheck(a, b, dtype.Ord); err != nil { + return nil, errors.Wrapf(err, "Lt failed") } var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -226,7 +240,6 @@ func (e StdEng) Lt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err erro err = e.E.LtIter(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse } - return } @@ -243,7 +256,6 @@ func (e StdEng) Lt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err erro err = e.E.Lt(typ, dataA, dataB, dataReuse) retVal = reuse } - return } @@ -252,18 +264,23 @@ func (e StdEng) Lt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err erro //UseUnsafe() will ensure that the same type is returned. 
// Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) Lte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, ordTypes); err != nil { + if err = binaryCheck(a, b, dtype.Ord); err != nil { + return nil, errors.Wrapf(err, "Lte failed") } var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -306,7 +323,6 @@ func (e StdEng) Lte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err err = e.E.LteIter(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse } - return } @@ -323,7 +339,6 @@ func (e StdEng) Lte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err err = e.E.Lte(typ, dataA, dataB, dataReuse) retVal = reuse } - return } @@ -332,18 +347,23 @@ func (e StdEng) Lte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err //UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) ElEq(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, eqTypes); err != nil { + if err = binaryCheck(a, b, dtype.Eq); err != nil { + return nil, errors.Wrapf(err, "Eq failed") } var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -386,7 +406,6 @@ func (e StdEng) ElEq(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err er err = e.E.EqIter(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse } - return } @@ -403,7 +422,6 @@ func (e StdEng) ElEq(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err er err = e.E.Eq(typ, dataA, dataB, dataReuse) retVal = reuse } - return } @@ -412,18 +430,23 @@ func (e StdEng) ElEq(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err er //UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. 
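// For illustration only (not part of this patch): the comparison semantics documented
// above, via the package-level API. By default the result is a tensor of bools;
// AsSameType() (or UseUnsafe(), which forces same-type results) returns 0/1 values in
// the operands' own dtype instead.
package main

import (
    "fmt"

    "gorgonia.org/tensor"
)

func main() {
    a := tensor.New(tensor.WithShape(3), tensor.WithBacking([]float64{1, 5, 3}))
    b := tensor.New(tensor.WithShape(3), tensor.WithBacking([]float64{2, 2, 3}))

    gt, _ := tensor.Gt(a, b)                          // bools: [false true false]
    gtSame, _ := tensor.Gt(a, b, tensor.AsSameType()) // float64s: [0 1 0]
    fmt.Println(gt.Data(), gtSame.Data())
}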
func (e StdEng) ElNe(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, eqTypes); err != nil { + if err = binaryCheck(a, b, dtype.Eq); err != nil { + return nil, errors.Wrapf(err, "Ne failed") } var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -466,7 +489,6 @@ func (e StdEng) ElNe(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err er err = e.E.NeIter(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse } - return } @@ -483,7 +505,6 @@ func (e StdEng) ElNe(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err er err = e.E.Ne(typ, dataA, dataB, dataReuse) retVal = reuse } - return } @@ -492,31 +513,39 @@ func (e StdEng) ElNe(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err er // UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) GtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, ordTypes); err != nil { + if err = unaryCheck(t, dtype.Ord); err != nil { return nil, errors.Wrapf(err, "Gt failed") } + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "Gt failed") + } + var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } a := t typ := t.Dtype().Type var ait, bit, iit Iterator var dataA, dataB, dataReuse, scalarHeader *storage.Header - var useIter bool + var useIter, newAlloc bool if leftTensor { - if dataA, dataB, dataReuse, ait, iit, useIter, err = prepDataVS(t, s, reuse); err != nil { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Gt") } scalarHeader = dataB } else { - if dataA, dataB, dataReuse, bit, iit, useIter, err = prepDataSV(s, t, reuse); err != nil { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Gt") } scalarHeader = dataA @@ -559,12 +588,15 @@ func (e StdEng) GtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO err = e.E.GtIter(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } // handle special case where A and B have both len 1 - if dataB.L == 1 && dataB.L == 1 { + if len(dataA.Raw) == int(typ.Size()) && len(dataB.Raw) == int(typ.Size()) { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ, dataReuse, dataA) @@ -595,6 +627,9 @@ func (e StdEng) GtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO err = e.E.Gt(typ, dataA, dataB, dataReuse) retVal = reuse } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } @@ -604,31 +639,39 @@ func (e StdEng) GtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO // UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) GteScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, ordTypes); err != nil { + if err = unaryCheck(t, dtype.Ord); err != nil { return nil, errors.Wrapf(err, "Gte failed") } + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "Gte failed") + } + var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } a := t typ := t.Dtype().Type var ait, bit, iit Iterator var dataA, dataB, dataReuse, scalarHeader *storage.Header - var useIter bool + var useIter, newAlloc bool if leftTensor { - if dataA, dataB, dataReuse, ait, iit, useIter, err = prepDataVS(t, s, reuse); err != nil { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Gte") } scalarHeader = dataB } else { - if dataA, dataB, dataReuse, bit, iit, useIter, err = prepDataSV(s, t, reuse); err != nil { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Gte") } scalarHeader = dataA @@ -671,12 +714,15 @@ func (e StdEng) GteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func err = e.E.GteIter(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } // handle special case where A and B have both len 1 - if dataB.L == 1 && dataB.L == 1 { + if len(dataA.Raw) == int(typ.Size()) && len(dataB.Raw) == int(typ.Size()) { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ, dataReuse, dataA) @@ -707,6 +753,9 @@ func (e StdEng) GteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func err = e.E.Gte(typ, dataA, dataB, dataReuse) retVal = reuse } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } @@ -716,31 +765,39 @@ func (e StdEng) GteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func // UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) LtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, ordTypes); err != nil { + if err = unaryCheck(t, dtype.Ord); err != nil { return nil, errors.Wrapf(err, "Lt failed") } + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "Lt failed") + } + var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } a := t typ := t.Dtype().Type var ait, bit, iit Iterator var dataA, dataB, dataReuse, scalarHeader *storage.Header - var useIter bool + var useIter, newAlloc bool if leftTensor { - if dataA, dataB, dataReuse, ait, iit, useIter, err = prepDataVS(t, s, reuse); err != nil { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Lt") } scalarHeader = dataB } else { - if dataA, dataB, dataReuse, bit, iit, useIter, err = prepDataSV(s, t, reuse); err != nil { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Lt") } scalarHeader = dataA @@ -783,12 +840,15 @@ func (e StdEng) LtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO err = e.E.LtIter(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } // handle special case where A and B have both len 1 - if dataB.L == 1 && dataB.L == 1 { + if len(dataA.Raw) == int(typ.Size()) && len(dataB.Raw) == int(typ.Size()) { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ, dataReuse, dataA) @@ -819,6 +879,9 @@ func (e StdEng) LtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO err = e.E.Lt(typ, dataA, dataB, dataReuse) retVal = reuse } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } @@ -828,31 +891,39 @@ func (e StdEng) LtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO // UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) LteScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, ordTypes); err != nil { + if err = unaryCheck(t, dtype.Ord); err != nil { return nil, errors.Wrapf(err, "Lte failed") } + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "Lte failed") + } + var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } a := t typ := t.Dtype().Type var ait, bit, iit Iterator var dataA, dataB, dataReuse, scalarHeader *storage.Header - var useIter bool + var useIter, newAlloc bool if leftTensor { - if dataA, dataB, dataReuse, ait, iit, useIter, err = prepDataVS(t, s, reuse); err != nil { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Lte") } scalarHeader = dataB } else { - if dataA, dataB, dataReuse, bit, iit, useIter, err = prepDataSV(s, t, reuse); err != nil { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Lte") } scalarHeader = dataA @@ -895,12 +966,15 @@ func (e StdEng) LteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func err = e.E.LteIter(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } // handle special case where A and B have both len 1 - if dataB.L == 1 && dataB.L == 1 { + if len(dataA.Raw) == int(typ.Size()) && len(dataB.Raw) == int(typ.Size()) { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ, dataReuse, dataA) @@ -931,36 +1005,47 @@ func (e StdEng) LteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func err = e.E.Lte(typ, dataA, dataB, dataReuse) retVal = reuse } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } func (e StdEng) EqScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, eqTypes); err != nil { + if err = unaryCheck(t, dtype.Eq); err != nil { return nil, errors.Wrapf(err, "Eq failed") } + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "Eq failed") + } + var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } a := t typ := t.Dtype().Type var ait, bit, iit Iterator var dataA, dataB, dataReuse, scalarHeader *storage.Header - var useIter bool + var useIter, newAlloc bool if leftTensor { - if dataA, dataB, dataReuse, ait, iit, useIter, err = prepDataVS(t, s, reuse); err != nil { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Eq") } scalarHeader = dataB } else { - if dataA, dataB, dataReuse, bit, iit, useIter, err = prepDataSV(s, t, reuse); err != nil { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Eq") } scalarHeader = dataA @@ -1003,12 +1088,15 @@ func (e StdEng) EqScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO err = e.E.EqIter(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } // handle special case where A and B have both len 1 - if dataB.L == 1 && dataB.L == 1 { + if len(dataA.Raw) == int(typ.Size()) && len(dataB.Raw) == int(typ.Size()) { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ, dataReuse, dataA) @@ -1039,36 +1127,47 @@ func (e StdEng) EqScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO err = e.E.Eq(typ, dataA, dataB, dataReuse) retVal = reuse } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } func (e StdEng) NeScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, eqTypes); err != nil { + if err = unaryCheck(t, dtype.Eq); err != nil { return nil, errors.Wrapf(err, "Ne failed") } + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "Ne failed") + } + var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } a := t typ := t.Dtype().Type var ait, bit, iit Iterator var dataA, dataB, dataReuse, scalarHeader *storage.Header - var useIter bool + var useIter, newAlloc bool if leftTensor { - if dataA, dataB, dataReuse, ait, iit, useIter, err = prepDataVS(t, s, reuse); err != nil { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Ne") } scalarHeader = dataB } else { - if dataA, dataB, dataReuse, bit, iit, useIter, err = prepDataSV(s, t, reuse); err != nil { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Ne") } scalarHeader = dataA @@ -1111,12 +1210,15 @@ func (e StdEng) NeScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO err = e.E.NeIter(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } // handle special case where A and B have both len 1 - if dataB.L == 1 && dataB.L == 1 { + if len(dataA.Raw) == int(typ.Size()) && len(dataB.Raw) == int(typ.Size()) { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ, dataReuse, dataA) @@ -1147,6 +1249,9 @@ func (e StdEng) NeScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO err = e.E.Ne(typ, dataA, dataB, dataReuse) retVal = reuse } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } diff --git a/defaultengine_linalg.go b/defaultengine_linalg.go index 486e7a0..ca114fd 100644 --- a/defaultengine_linalg.go +++ b/defaultengine_linalg.go @@ -1,21 +1,27 @@ package tensor import ( + "context" "reflect" "github.com/pkg/errors" "gonum.org/v1/gonum/blas" "gonum.org/v1/gonum/mat" + "gorgonia.org/dtype" ) -// Trace returns the trace of a matrix (i.e. the sum of the diagonal elements). If the Tensor provided is not a matrix, it will return an error -func (e StdEng) Trace(t Tensor) (retVal interface{}, err error) { +// Trace returns the trace of a matrix (i.e. the sum of the diagonal elements). If the Tensor provided is not a matrix, it will return an error +func (e StdEng) Trace(ctx context.Context, t Tensor) (retVal interface{}, err error) { + if err := handleCtx(ctx); err != nil { + return nil, err + } + if t.Dims() != 2 { err = errors.Errorf(dimMismatch, 2, t.Dims()) return } - if err = typeclassCheck(t.Dtype(), numberTypes); err != nil { + if err = dtype.TypeClassCheck(t.Dtype(), dtype.Number); err != nil { return nil, errors.Wrap(err, "Trace") } @@ -118,6 +124,12 @@ func (e StdEng) Trace(t Tensor) (retVal interface{}, err error) { } func (e StdEng) Dot(x, y Tensor, opts ...FuncOpt) (retVal Tensor, err error) { + fo := ParseFuncOpts(opts...) + ctx := fo.Context() + if err = handleCtx(ctx); err != nil { + return nil, err + } + if _, ok := x.(DenseTensor); !ok { err = errors.Errorf("Engine only supports working on x that is a DenseTensor. Got %T instead", x) return @@ -138,8 +150,6 @@ func (e StdEng) Dot(x, y Tensor, opts ...FuncOpt) (retVal Tensor, err error) { return } - fo := ParseFuncOpts(opts...) 
- var reuse, incr DenseTensor if reuse, err = getFloatDenseTensor(fo.reuse); err != nil { err = errors.Wrapf(err, opFail, "Dot - reuse") @@ -211,7 +221,7 @@ func (e StdEng) Dot(x, y Tensor, opts ...FuncOpt) (retVal Tensor, err error) { return } var ret interface{} - if ret, err = e.Inner(a, b); err != nil { + if ret, err = e.Inner(ctx, a, b); err != nil { return nil, errors.Wrapf(err, opFail, "Dot") } return New(FromScalar(ret)), nil @@ -286,13 +296,13 @@ func (e StdEng) Dot(x, y Tensor, opts ...FuncOpt) (retVal Tensor, err error) { var rd *Dense if rd, err = a.TensorMul(b, axesA, axesB); err != nil { - return + panic(err) } if reuse != nil { copyDense(reuse, rd) - ReturnAP(reuse.Info()) - reuse.setAP(rd.Info().Clone()) + ap := rd.Info().Clone() + reuse.setAP(&ap) defer ReturnTensor(rd) // swap out the underlying data and metadata // reuse.data, rd.data = rd.data, reuse.data @@ -308,16 +318,20 @@ func (e StdEng) Dot(x, y Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } // TODO: make it take DenseTensor -func (e StdEng) SVD(a Tensor, uv, full bool) (s, u, v Tensor, err error) { +func (e StdEng) SVD(ctx context.Context, a Tensor, uv, full bool) (s, u, v Tensor, err error) { + if err = handleCtx(ctx); err != nil { + return nil, nil, nil, err + } + var t *Dense var ok bool if err = e.checkAccessible(a); err != nil { - return nil, nil, nil, errors.Wrapf(err, "opFail", "SVD") + return nil, nil, nil, errors.Wrapf(err, "opFail %v", "SVD") } if t, ok = a.(*Dense); !ok { return nil, nil, nil, errors.Errorf("StdEng only performs SVDs for DenseTensors. Got %T instead", a) } - if !isFloat(t.Dtype()) { + if err = dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return nil, nil, nil, errors.Errorf("StdEng can only perform SVDs for float64 and float32 type. Got tensor of %v instead", t.Dtype()) } @@ -355,7 +369,7 @@ func (e StdEng) SVD(a Tensor, uv, full bool) (s, u, v Tensor, err error) { // extract values var um, vm mat.Dense - s = recycledDense(Float64, Shape{MinInt(t.Shape()[0], t.Shape()[1])}) + s = recycledDense(Float64, Shape{MinInt(t.Shape()[0], t.Shape()[1])}, WithEngine(e)) svd.Values(s.Data().([]float64)) if uv { svd.UTo(&um) @@ -371,9 +385,13 @@ func (e StdEng) SVD(a Tensor, uv, full bool) (s, u, v Tensor, err error) { // Inner is a thin layer over BLAS's D/Sdot. // It returns a scalar value, wrapped in an interface{}, which is not quite nice. -func (e StdEng) Inner(a, b Tensor) (retVal interface{}, err error) { +func (e StdEng) Inner(ctx context.Context, a, b Tensor) (retVal interface{}, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } + var ad, bd DenseTensor - if ad, bd, err = e.checkTwoFloatTensors(a, b); err != nil { + if ad, bd, err = e.checkTwoFloatComplexTensors(a, b); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Inner") } @@ -384,6 +402,12 @@ func (e StdEng) Inner(a, b Tensor) (retVal interface{}, err error) { case []float64: B := bd.Float64s() retVal = whichblas.Ddot(len(A), A, 1, B, 1) + case []complex64: + B := bd.Complex64s() + retVal = whichblas.Cdotu(len(A), A, 1, B, 1) + case []complex128: + B := bd.Complex128s() + retVal = whichblas.Zdotu(len(A), A, 1, B, 1) } return } @@ -392,10 +416,14 @@ func (e StdEng) Inner(a, b Tensor) (retVal interface{}, err error) { // Because DGEMV computes: // y = αA * x + βy // we set beta to 0, so we don't have to manually zero out the reused/retval tensor data -func (e StdEng) MatVecMul(a, b, prealloc Tensor) (err error) { +func (e StdEng) MatVecMul(ctx context.Context, a, b, prealloc Tensor) (err error) { + if err := handleCtx(ctx); err != nil { + return err + } + // check all are DenseTensors var ad, bd, pd DenseTensor - if ad, bd, pd, err = e.checkThreeFloatTensors(a, b, prealloc); err != nil { + if ad, bd, pd, err = e.checkThreeFloatComplexTensors(a, b, prealloc); err != nil { return errors.Wrapf(err, opFail, "StdEng.MatVecMul") } @@ -403,12 +431,35 @@ func (e StdEng) MatVecMul(a, b, prealloc Tensor) (err error) { n := ad.oshape()[1] tA := blas.NoTrans - if ad.oldAP() != nil { + do := a.DataOrder() + z := ad.oldAP().IsZero() + + var lda int + switch { + case do.IsRowMajor() && z: + lda = n + case do.IsRowMajor() && !z: tA = blas.Trans + lda = n + case do.IsColMajor() && z: + tA = blas.Trans + lda = m + m, n = n, m + case do.IsColMajor() && !z: + lda = m + m, n = n, m } - lda := ad.ostrides()[0] + incX, incY := 1, 1 // step size + // ASPIRATIONAL TODO: different incX and incY + // TECHNICAL DEBT. TECHDEBT. 
TECH DEBT + // Example use case: + // log.Printf("a %v %v", ad.Strides(), ad.ostrides()) + // log.Printf("b %v", b.Strides()) + // incX := a.Strides()[0] + // incY = b.Strides()[0] + switch A := ad.Data().(type) { case []float64: x := bd.Float64s() @@ -420,8 +471,18 @@ func (e StdEng) MatVecMul(a, b, prealloc Tensor) (err error) { y := pd.Float32s() alpha, beta := float32(1), float32(0) whichblas.Sgemv(tA, m, n, alpha, A, lda, x, incX, beta, y, incY) + case []complex64: + x := bd.Complex64s() + y := pd.Complex64s() + var alpha, beta complex64 = complex(1, 0), complex(0, 0) + whichblas.Cgemv(tA, m, n, alpha, A, lda, x, incX, beta, y, incY) + case []complex128: + x := bd.Complex128s() + y := pd.Complex128s() + var alpha, beta complex128 = complex(1, 0), complex(0, 0) + whichblas.Zgemv(tA, m, n, alpha, A, lda, x, incX, beta, y, incY) default: - return errors.Errorf(typeNYI, "matVecMul", bd.Data()) + return nyierr(typeNYI, bd.Data()) } return nil @@ -431,56 +492,72 @@ func (e StdEng) MatVecMul(a, b, prealloc Tensor) (err error) { // DGEMM computes: // C = αA * B + βC // To prevent needless zeroing out of the slice, we just set β to 0 -func (e StdEng) MatMul(a, b, prealloc Tensor) (err error) { +func (e StdEng) MatMul(ctx context.Context, a, b, prealloc Tensor) (err error) { + if err := handleCtx(ctx); err != nil { + return err + } + // check all are DenseTensors var ad, bd, pd DenseTensor - if ad, bd, pd, err = e.checkThreeFloatTensors(a, b, prealloc); err != nil { + if ad, bd, pd, err = e.checkThreeFloatComplexTensors(a, b, prealloc); err != nil { return errors.Wrapf(err, opFail, "StdEng.MatMul") } - tA, tB := blas.NoTrans, blas.NoTrans - if ad.oldAP() != nil { - tA = blas.Trans - } - - // Special case if b is (1, N) - if bd.oldAP() != nil || bd.IsRowVec() { - tB = blas.Trans - } + ado := a.DataOrder() + bdo := b.DataOrder() + cdo := prealloc.DataOrder() + // get result shapes. 
k is the shared dimension + // a is (m, k) + // b is (k, n) + // c is (m, n) var m, n, k int m = ad.Shape()[0] k = ad.Shape()[1] n = bd.Shape()[1] // wrt the strides, we use the original strides, because that's what BLAS needs, instead of calling .Strides() - lda := ad.ostrides()[0] - ldb := bd.ostrides()[0] - ldc := pd.ostrides()[0] + // lda in colmajor = number of rows; + // lda in row major = number of cols + var lda, ldb, ldc int + switch { + case ado.IsColMajor(): + lda = m + case ado.IsRowMajor(): + lda = k + } - // special case: if a is (1, N) x (N, M), then we can just use GEMV - if ad.IsRowVec() { - tB = blas.Trans - if bd.oldAP() != nil { - tB = blas.NoTrans + switch { + case bdo.IsColMajor(): + ldb = bd.Shape()[0] + case bdo.IsRowMajor(): + ldb = n + } + + switch { + case cdo.IsColMajor(): + ldc = prealloc.Shape()[0] + case cdo.IsRowMajor(): + ldc = prealloc.Shape()[1] + } + + // check for trans + tA, tB := blas.NoTrans, blas.NoTrans + if !ad.oldAP().IsZero() { + tA = blas.Trans + if ado.IsRowMajor() { + lda = m + } else { + lda = k } - m = bd.Shape()[0] - n = bd.Shape()[1] - switch A := ad.Data().(type) { - case []float64: - B := bd.Float64s() - C := pd.Float64s() - alpha, beta := float64(1), float64(0) - whichblas.Dgemv(tB, m, n, alpha, B, ldb, A, lda, beta, C, ldc) - case []float32: - B := bd.Float32s() - C := pd.Float32s() - alpha, beta := float32(1), float32(0) - whichblas.Sgemv(tB, m, n, alpha, B, ldb, A, lda, beta, C, ldc) - default: - return errors.Errorf(typeNYI, "matMul a is row vec", ad.Data()) + } + if !bd.oldAP().IsZero() { + tB = blas.Trans + if bdo.IsRowMajor() { + ldb = bd.Shape()[0] + } else { + ldb = bd.Shape()[1] } - return } switch A := ad.Data().(type) { @@ -488,34 +565,93 @@ func (e StdEng) MatMul(a, b, prealloc Tensor) (err error) { B := bd.Float64s() C := pd.Float64s() alpha, beta := float64(1), float64(0) - whichblas.Dgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) + if ado.IsColMajor() && bdo.IsColMajor() { + whichblas.Dgemm(tA, tB, n, m, k, alpha, B, ldb, A, lda, beta, C, ldc) + } else { + whichblas.Dgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) + } case []float32: B := bd.Float32s() C := pd.Float32s() alpha, beta := float32(1), float32(0) - whichblas.Sgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) + if ado.IsColMajor() && bdo.IsColMajor() { + whichblas.Sgemm(tA, tB, n, m, k, alpha, B, ldb, A, lda, beta, C, ldc) + } else { + whichblas.Sgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) + } + case []complex64: + B := bd.Complex64s() + C := pd.Complex64s() + var alpha, beta complex64 = complex(1, 0), complex(0, 0) + if ado.IsColMajor() && bdo.IsColMajor() { + whichblas.Cgemm(tA, tB, n, m, k, alpha, B, ldb, A, lda, beta, C, ldc) + } else { + whichblas.Cgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) + } + case []complex128: + B := bd.Complex128s() + C := pd.Complex128s() + var alpha, beta complex128 = complex(1, 0), complex(0, 0) + if ado.IsColMajor() && bdo.IsColMajor() { + whichblas.Zgemm(tA, tB, n, m, k, alpha, B, ldb, A, lda, beta, C, ldc) + } else { + whichblas.Zgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) + } default: - return errors.Errorf(typeNYI, "matMul", ad.Data()) + return nyierr(typeNYI, ad.Data()) + } return } // Outer is a thin wrapper over S/Dger -func (e StdEng) Outer(a, b, prealloc Tensor) (err error) { +func (e StdEng) Outer(ctx context.Context, a, b, prealloc Tensor) (err error) { + if err = handleCtx(ctx); err != nil { + return err + } + // check all are DenseTensors 
var ad, bd, pd DenseTensor - if ad, bd, pd, err = e.checkThreeFloatTensors(a, b, prealloc); err != nil { + if ad, bd, pd, err = e.checkThreeFloatComplexTensors(a, b, prealloc); err != nil { return errors.Wrapf(err, opFail, "StdEng.Outer") } m := ad.Size() n := bd.Size() + pdo := pd.DataOrder() // the stride of a Vector is always going to be [1], // incX := t.Strides()[0] // incY := other.Strides()[0] incX, incY := 1, 1 - lda := pd.Strides()[0] + // lda := pd.Strides()[0] + var lda int + switch { + case pdo.IsColMajor(): + aShape := a.Shape().Clone() + bShape := b.Shape().Clone() + if err = a.Reshape(aShape[0], 1); err != nil { + return err + } + if err = b.Reshape(1, bShape[0]); err != nil { + return err + } + + if err = e.MatMul(ctx, a, b, prealloc); err != nil { + return err + } + + if err = b.Reshape(bShape...); err != nil { + return + } + if err = a.Reshape(aShape...); err != nil { + return + } + return nil + + case pdo.IsRowMajor(): + lda = pd.Shape()[1] + } switch x := ad.Data().(type) { case []float64: @@ -528,8 +664,18 @@ func (e StdEng) Outer(a, b, prealloc Tensor) (err error) { A := pd.Float32s() alpha := float32(1) whichblas.Sger(m, n, alpha, x, incX, y, incY, A, lda) + case []complex64: + y := bd.Complex64s() + A := pd.Complex64s() + var alpha complex64 = complex(1, 0) + whichblas.Cgeru(m, n, alpha, x, incX, y, incY, A, lda) + case []complex128: + y := bd.Complex128s() + A := pd.Complex128s() + var alpha complex128 = complex(1, 0) + whichblas.Zgeru(m, n, alpha, x, incX, y, incY, A, lda) default: - return errors.Errorf(typeNYI, "outer", b.Data()) + return nyierr(typeNYI, b.Data()) } return nil } @@ -559,13 +705,13 @@ func (e StdEng) checkTwoFloatTensors(a, b Tensor) (ad, bd DenseTensor, err error func (e StdEng) checkThreeFloatTensors(a, b, ret Tensor) (ad, bd, retVal DenseTensor, err error) { if err = e.checkAccessible(a); err != nil { - return nil, nil, nil, errors.Wrap(err, "checkTwoTensors: a is not accessible") + return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: a is not accessible") } if err = e.checkAccessible(b); err != nil { - return nil, nil, nil, errors.Wrap(err, "checkTwoTensors: a is not accessible") + return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: a is not accessible") } if err = e.checkAccessible(ret); err != nil { - return nil, nil, nil, errors.Wrap(err, "checkTwoTensors: ret is not accessible") + return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: ret is not accessible") } if a.Dtype() != b.Dtype() || b.Dtype() != ret.Dtype() { @@ -583,3 +729,51 @@ func (e StdEng) checkThreeFloatTensors(a, b, ret Tensor) (ad, bd, retVal DenseTe } return } + +func (e StdEng) checkTwoFloatComplexTensors(a, b Tensor) (ad, bd DenseTensor, err error) { + if err = e.checkAccessible(a); err != nil { + return nil, nil, errors.Wrap(err, "checkTwoTensors: a is not accessible") + } + if err = e.checkAccessible(b); err != nil { + return nil, nil, errors.Wrap(err, "checkTwoTensors: a is not accessible") + } + + if a.Dtype() != b.Dtype() { + return nil, nil, errors.New("Expected a and b to have the same Dtype") + } + + if ad, err = getFloatComplexDenseTensor(a); err != nil { + return nil, nil, errors.Wrap(err, "checkTwoTensors expects a to be be a DenseTensor") + } + if bd, err = getFloatComplexDenseTensor(b); err != nil { + return nil, nil, errors.Wrap(err, "checkTwoTensors expects b to be be a DenseTensor") + } + return +} + +func (e StdEng) checkThreeFloatComplexTensors(a, b, ret Tensor) (ad, bd, retVal DenseTensor, err error) { + if err = e.checkAccessible(a); 
err != nil { + return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: a is not accessible") + } + if err = e.checkAccessible(b); err != nil { + return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: a is not accessible") + } + if err = e.checkAccessible(ret); err != nil { + return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: ret is not accessible") + } + + if a.Dtype() != b.Dtype() || b.Dtype() != ret.Dtype() { + return nil, nil, nil, errors.New("Expected a and b and retVal all to have the same Dtype") + } + + if ad, err = getFloatComplexDenseTensor(a); err != nil { + return nil, nil, nil, errors.Wrap(err, "checkTwoTensors expects a to be be a DenseTensor") + } + if bd, err = getFloatComplexDenseTensor(b); err != nil { + return nil, nil, nil, errors.Wrap(err, "checkTwoTensors expects b to be be a DenseTensor") + } + if retVal, err = getFloatComplexDenseTensor(ret); err != nil { + return nil, nil, nil, errors.Wrap(err, "checkTwoTensors expects retVal to be be a DenseTensor") + } + return +} diff --git a/defaultengine_mapreduce.go b/defaultengine_mapreduce.go index 203a839..2c7ccd4 100644 --- a/defaultengine_mapreduce.go +++ b/defaultengine_mapreduce.go @@ -1,35 +1,41 @@ package tensor import ( + "context" "reflect" + "sort" "github.com/pkg/errors" + "gorgonia.org/tensor/internal/execution" "gorgonia.org/tensor/internal/storage" ) func (e StdEng) Map(fn interface{}, a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, nil); err != nil { + if err = unaryCheck(a, nilTC); err != nil { err = errors.Wrap(err, "Failed Map()") return } + if _, ok := a.(DenseTensor); !ok { + return nil, errors.Errorf("StdEng's Map method only supports dense tensors for now. Please put in a Pull Request to support other forms of Tensors. The file is: defaultengine_mapreduce.go") + } var reuse DenseTensor var safe, _, incr bool - if reuse, safe, _, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return } + if err = handleCtx(ctx); err != nil { + return nil, err // will be noopError{}, no need to wrap. 
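// Map applies fn elementwise to a DenseTensor and is what the Dense.Apply front end
// dispatches to. A minimal usage sketch (values are illustrative; assumes the usual
// Apply entry point on *Dense):
//
//	a := New(WithShape(2, 2), WithBacking([]float64{1, 2, 3, 4}))
//	b, err := a.Apply(func(x float64) float64 { return x * 2 })
//	// b holds [2, 4, 6, 8] in shape (2, 2); a is untouched because the default is safe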
+ } switch { case safe && reuse == nil: // create reuse - if v, ok := a.(View); ok { - if v.IsMaterializable() { - reuse = v.Materialize().(DenseTensor) - } else { - reuse = v.Clone().(DenseTensor) - } + if v, ok := a.(View); ok && v.IsMaterializable() { + reuse = v.Materialize().(DenseTensor) } else { - reuse = New(Of(a.Dtype()), WithShape(a.Shape().Clone()...)) + reuse = a.Clone().(DenseTensor) } case reuse != nil: if !reuse.IsNativelyAccessible() { @@ -73,7 +79,7 @@ func (e StdEng) Map(fn interface{}, a Tensor, opts ...FuncOpt) (retVal Tensor, e // SET RETVAL switch { case reuse != nil: - if err = reuseCheckShape(reuse, a.Shape()); err != nil { + if err = checkFixShape(reuse, a.Shape()); err != nil { err = errors.Wrapf(err, "Reuse shape check failed") return } @@ -102,18 +108,18 @@ func (e StdEng) Reduce(fn interface{}, a Tensor, axis int, defaultValue interfac // actual call out to the internal engine switch { - case (axis == 0 && at.DataOrder().isRowMajor()) || ((axis == lastAxis || axis == len(a.Shape())-1) && at.DataOrder().isColMajor()): + case (axis == 0 && at.DataOrder().IsRowMajor()) || ((axis == lastAxis || axis == len(a.Shape())-1) && at.DataOrder().IsColMajor()): var size, split int - if at.DataOrder().isColMajor() { + if at.DataOrder().IsColMajor() { return nil, errors.Errorf("NYI: colmajor") } size = a.Shape()[0] split = a.DataSize() / size storage.CopySliced(typ, dataReuse, 0, split, dataA, 0, split) err = e.E.ReduceFirst(typ, dataA, dataReuse, split, size, fn) - case (axis == lastAxis && at.DataOrder().isRowMajor()) || (axis == 0 && at.DataOrder().isColMajor()): + case (axis == lastAxis && at.DataOrder().IsRowMajor()) || (axis == 0 && at.DataOrder().IsColMajor()): var dimSize int - if at.DataOrder().isColMajor() { + if at.DataOrder().IsColMajor() { return nil, errors.Errorf("NYI: colmajor") } dimSize = a.Shape()[axis] @@ -147,18 +153,18 @@ func (e StdEng) OptimizedReduce(a Tensor, axis int, firstFn, lastFn, defaultFn, // actual call out to the internal engine switch { - case (axis == 0 && at.DataOrder().isRowMajor()) || ((axis == lastAxis || axis == len(a.Shape())-1) && at.DataOrder().isColMajor()): + case (axis == 0 && at.DataOrder().IsRowMajor()) || ((axis == lastAxis || axis == len(a.Shape())-1) && at.DataOrder().IsColMajor()): var size, split int - if at.DataOrder().isColMajor() { + if at.DataOrder().IsColMajor() { return nil, errors.Errorf("NYI: colmajor") } size = a.Shape()[0] split = a.DataSize() / size storage.CopySliced(typ, dataReuse, 0, split, dataA, 0, split) err = e.E.ReduceFirst(typ, dataA, dataReuse, split, size, firstFn) - case (axis == lastAxis && at.DataOrder().isRowMajor()) || (axis == 0 && at.DataOrder().isColMajor()): + case (axis == lastAxis && at.DataOrder().IsRowMajor()) || (axis == 0 && at.DataOrder().IsColMajor()): var dimSize int - if at.DataOrder().isColMajor() { + if at.DataOrder().IsColMajor() { return nil, errors.Errorf("NYI: colmajor") } dimSize = a.Shape()[axis] @@ -175,100 +181,48 @@ func (e StdEng) OptimizedReduce(a Tensor, axis int, firstFn, lastFn, defaultFn, return } -func (e StdEng) Sum(a Tensor, along ...int) (retVal Tensor, err error) { - switch at := a.(type) { - case *Dense: - hdr := at.hdr() - typ := at.t.Type - monotonic, incr1 := IsMonotonicInts(along) // if both are true, then it means all axes are accounted for, then it'll return a scalar value - if (monotonic && incr1 && len(along) == a.Dims()) || len(along) == 0 { - var ret interface{} - if ret, err = execution.MonotonicSum(typ, hdr); err != nil { - return - } - return 
New(FromScalar(ret)), nil - } - var firstFn, lastFn, defaultFn interface{} - if firstFn, lastFn, defaultFn, err = execution.SumMethods(typ); err != nil { - return - } - defaultVal := reflect.Zero(typ).Interface() - - retVal = a - prev := -1 - dims := len(retVal.Shape()) - - for _, axis := range along { - if prev == -1 { - prev = axis - } - if axis > prev { - axis-- - } - - if axis >= dims { - err = errors.Errorf(dimMismatch, retVal.Dims(), axis) - return - } - if retVal, err = e.OptimizedReduce(retVal, axis, firstFn, lastFn, defaultFn, defaultVal); err != nil { - return - } - } - return +func (e StdEng) Sum(ctx context.Context, a Tensor, along ...int) (retVal Tensor, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } - default: - return nil, errors.Errorf("Cannot perform Sum on %T", a) + a2 := a + if v, ok := a.(View); ok && v.IsMaterializable() { + a2 = v.Materialize() } + return e.reduce("Sum", execution.MonotonicSum, execution.SumMethods, a2, along...) } -func (e StdEng) Min(a Tensor, along ...int) (retVal Tensor, err error) { - switch at := a.(type) { - case *Dense: - hdr := at.hdr() - typ := at.t.Type - monotonic, incr1 := IsMonotonicInts(along) // if both are true, then it means all axes are accounted for, then it'll return a scalar value - if (monotonic && incr1 && len(along) == a.Dims()) || len(along) == 0 { - var ret interface{} - if ret, err = execution.MonotonicMin(typ, hdr); err != nil { - return - } - return New(FromScalar(ret)), nil - } - var firstFn, lastFn, defaultFn interface{} - if firstFn, lastFn, defaultFn, err = execution.MinMethods(typ); err != nil { - return - } - defaultVal := reflect.Zero(typ).Interface() - - retVal = a - prev := -1 - dims := len(retVal.Shape()) - - for _, axis := range along { - if prev == -1 { - prev = axis - } - if axis > prev { - axis-- - } +func (e StdEng) Min(ctx context.Context, a Tensor, along ...int) (retVal Tensor, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } - if axis >= dims { - err = errors.Errorf(dimMismatch, retVal.Dims(), axis) - return - } + a2 := a + if v, ok := a.(View); ok && v.IsMaterializable() { + a2 = v.Materialize() + } + return e.reduce("Min", execution.MonotonicMin, execution.MinMethods, a2, along...) +} - if retVal, err = e.OptimizedReduce(retVal, axis, firstFn, lastFn, defaultFn, defaultVal); err != nil { - return - } - } - return +func (e StdEng) Max(ctx context.Context, a Tensor, along ...int) (retVal Tensor, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } - default: - return nil, errors.Errorf("Cannot perform Min on %T", a) + a2 := a + if v, ok := a.(View); ok && v.IsMaterializable() { + a2 = v.Materialize() } + return e.reduce("Max", execution.MonotonicMax, execution.MaxMethods, a2, along...) 
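// Sum, Min and Max all funnel into the reduce helper below. Multi-axis reduction is
// done one axis at a time, so axis indices must shrink as earlier reductions remove
// dimensions; that is what the sort + dimsReduced adjustment in reduce() computes.
// A worked sketch (shapes are illustrative only):
//
//	a has shape (2, 3, 4), along = []int{0, 2}
//	reduce axis 0          -> shape (3, 4); the original axis 2 now refers to axis 2-1 = 1
//	reduce adjusted axis 1 -> shape (3)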
} -func (e StdEng) Max(a Tensor, along ...int) (retVal Tensor, err error) { +func (e StdEng) reduce( + op string, + monotonicMethod func(t reflect.Type, a *storage.Header) (interface{}, error), + methods func(t reflect.Type) (interface{}, interface{}, interface{}, error), + a Tensor, + along ...int) (retVal Tensor, err error) { switch at := a.(type) { case *Dense: hdr := at.hdr() @@ -276,30 +230,25 @@ func (e StdEng) Max(a Tensor, along ...int) (retVal Tensor, err error) { monotonic, incr1 := IsMonotonicInts(along) // if both are true, then it means all axes are accounted for, then it'll return a scalar value if (monotonic && incr1 && len(along) == a.Dims()) || len(along) == 0 { var ret interface{} - if ret, err = execution.MonotonicMax(typ, hdr); err != nil { + if ret, err = monotonicMethod(typ, hdr); err != nil { return } return New(FromScalar(ret)), nil } var firstFn, lastFn, defaultFn interface{} - if firstFn, lastFn, defaultFn, err = execution.MaxMethods(typ); err != nil { + if firstFn, lastFn, defaultFn, err = methods(typ); err != nil { return } defaultVal := reflect.Zero(typ).Interface() retVal = a - prev := -1 - dims := len(retVal.Shape()) + dimsReduced := 0 + sort.Slice(along, func(i, j int) bool { return along[i] < along[j] }) for _, axis := range along { - if prev == -1 { - prev = axis - } - if axis > prev { - axis-- - } - - if axis >= dims { + axis -= dimsReduced + dimsReduced++ + if axis >= retVal.Dims() { err = errors.Errorf(dimMismatch, retVal.Dims(), axis) return } @@ -311,8 +260,9 @@ func (e StdEng) Max(a Tensor, along ...int) (retVal Tensor, err error) { return default: - return nil, errors.Errorf("Cannot perform Max on %T", a) + return nil, errors.Errorf("Cannot perform %s on %T", op, a) } + } func (StdEng) prepReduce(a Tensor, axis int, opts ...FuncOpt) (at, reuse DenseTensor, dataA, dataReuse *storage.Header, err error) { @@ -321,17 +271,21 @@ func (StdEng) prepReduce(a Tensor, axis int, opts ...FuncOpt) (at, reuse DenseTe return } - if err = unaryCheck(a, nil); err != nil { + if err = unaryCheck(a, nilTC); err != nil { err = errors.Wrap(err, "prepReduce failed") return } // FUNC PREP var safe bool - if reuse, safe, _, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { err = errors.Wrap(err, "Unable to prep unary tensor") return } + if err = handleCtx(ctx); err != nil { + return + } var newShape Shape for i, s := range a.Shape() { diff --git a/defaultengine_matop_gatherscatter.go b/defaultengine_matop_gatherscatter.go new file mode 100644 index 0000000..b28d4ea --- /dev/null +++ b/defaultengine_matop_gatherscatter.go @@ -0,0 +1,200 @@ +package tensor + +import ( + "sync" + + "github.com/pkg/errors" +) + +func (e StdEng) Scatter(a, indices Tensor, opts ...FuncOpt) (retVal Tensor, err error) { + fo := ParseFuncOpts(opts...) + reuse := fo.Reuse() + + maxT, err := Max(indices) + if err != nil { + return nil, errors.Wrapf(err, "Cannot find the max of the indices") + } + max, ok := maxT.Data().(int) + if !ok { + return nil, errors.Errorf("Indices must be of ints. 
Got %v of %T instead", maxT.Data(), maxT.Data()) + } + + // expected shape + shp := indices.Shape().Clone() + shp[len(shp)-1] = max + 1 + + switch { + case reuse == nil && fo.Safe(): + // create reuse + reuse = New(WithShape(shp...), Of(a.Dtype())) + case reuse == nil && !fo.Safe(): + // check shape of `a` - the last dim of a must be at least max+1 + if a.Shape()[a.Dims()-1] < max+1 { + return nil, errors.Errorf("Cannot Scatter - the last dim of `a` %v must be at least %v, which is the maximum value of the indices + 1", a.Shape(), max+1) + } + reuse = a + case reuse != nil: + // check shape of `reuse` - last dim of `reuse` must at least be as large as max+1 + if reuse.Shape()[reuse.Dims()-1] < max+1 { + return nil, errors.Errorf("Cannot Scatter. The last dim of `reuse` %v must be at least %v, which is the maximum value off the indices + 1", reuse.Shape(), max+1) + } + } + + oldShape := a.Shape().Clone() + oldIndicesShape := a.Shape().Clone() + reuseOldShape := reuse.Shape().Clone() + defer func() { a.Reshape(oldShape...); indices.Reshape(oldIndicesShape...); reuse.Reshape(reuseOldShape...) }() + + switch { + case indices.Shape().IsVectorLike(): + idx := indices.Data().([]int) + _ = idx + // TODO + default: + // THIS IS ROW MAJOR ONLY + // THIS IS DENSE TENSOR ONLY + + a := a.(DenseTensor) + indices := indices.(DenseTensor) + reuse := reuse.(DenseTensor) + + // reshape everything into a matrix + a.Reshape(asMat(a.Shape(), a.Dims()-1, true)...) + indices.Reshape(asMat(indices.Shape(), indices.Dims()-1, true)...) + reuse.Reshape(asMat(reuse.Shape(), reuse.Dims()-1, true)...) + + // check that indices' shape[0] is <= a.Shape[0] + if indices.Shape()[0] > a.Shape()[0] { + // something is wrong + return nil, errors.Errorf("Cannot scatter") + } + + // now they are all matrices, we can iterate thru them + var ps []iteratorPair + for i := 0; i < indices.Shape()[0]; i++ { + ait := AxialIteratorFromDense(a, 0, i, true) + iit := AxialIteratorFromDense(indices, 0, i, true) + + ps = append(ps, iteratorPair{ait, iit, i}) + } + + errChan := make(chan error, len(ps)) + var wg sync.WaitGroup + for i := range ps { + wg.Add(1) + // note: be careful not to use `for i, p := range ps` + // and then use `go p.coiter`. + // This is because `p` is would not be captured by `go`, + // thus every `p` would be `ps[len(ps)-1]`. + go ps[i].coiter(a, indices, reuse, errChan, &wg) + } + wg.Wait() + close(errChan) + err = <-errChan // maybe get ALL the errors from errChan? 
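// The loop above indexes ps[i] instead of ranging over ps because, before Go 1.22,
// the range variable is a single reused variable, so capturing it in a goroutine is
// a classic pitfall. A minimal illustration of what the comment above warns against
// (sketch only):
//
//	for _, p := range ps {
//		go p.coiter(a, indices, reuse, errChan, &wg) // BUG: every goroutine may see the last p
//	}
//	for i := range ps {
//		go ps[i].coiter(a, indices, reuse, errChan, &wg) // OK: each goroutine gets its own element
//	}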
+ return reuse, err + + } + + panic("unreachable") +} + +type iteratorPair struct { + a *AxialIterator + idx *AxialIterator + axis int +} + +func (it *iteratorPair) coiter(a, indices, reuse DenseTensor, errChan chan error, wg *sync.WaitGroup) { + defer wg.Done() + ii, err := it.idx.Start() + if err != nil { + if err = handleNoOp(err); err != nil { + errChan <- err + } + return + } + + iData := indices.Data().([]int) + retStride := reuse.Strides()[0] + switch { + case a.Dtype() == Float64 && reuse.Dtype() == Float64: + aData := a.Data().([]float64) + rData := reuse.Data().([]float64) + + var ai, ii int + if ai, err = it.a.Start(); err != nil { + goto reterr + } + if ii, err = it.idx.Start(); err != nil { + goto reterr + } + for { + + idx := iData[ii] + v := aData[ai] + + rData[it.axis*retStride+idx] = v + + if it.a.Done() || it.idx.Done() { + break + } + if ai, err = it.a.Next(); err != nil { + break + } + if ii, err = it.idx.Next(); err != nil { + break + } + } + case a.Dtype() == Float32 && reuse.Dtype() == Float32: + aData := a.Data().([]float32) + rData := reuse.Data().([]float32) + + var ai, ii int + if ai, err = it.a.Start(); err != nil { + goto reterr + } + if ii, err = it.idx.Start(); err != nil { + goto reterr + } + for { + + idx := iData[ii] + v := aData[ai] + + rData[it.axis*retStride+idx] = v + + if it.a.Done() || it.idx.Done() { + break + } + if ai, err = it.a.Next(); err != nil { + break + } + if ii, err = it.idx.Next(); err != nil { + break + } + } + + default: + + // generic + for ai, err := it.a.Start(); err == nil; ai, err = it.a.Next() { + if it.idx.Done() { + break + } + idx := iData[ii] + v := a.arrPtr().Get(ai) + reuse.Set(it.axis*retStride+idx, v) + + if ii, err = it.idx.Next(); err != nil { + break + } + } + } + +reterr: + if err = handleNoOp(err); err != nil { + errChan <- err + return + } + +} diff --git a/defaultengine_matop_misc.go b/defaultengine_matop_misc.go index 23607c6..56641d3 100644 --- a/defaultengine_matop_misc.go +++ b/defaultengine_matop_misc.go @@ -1,131 +1,441 @@ -package tensor - -import "github.com/pkg/errors" - -func (e StdEng) Repeat(t Tensor, axis int, repeats ...int) (Tensor, error) { - switch tt := t.(type) { - case DenseTensor: - return e.denseRepeat(tt, axis, repeats) - default: - return nil, errors.Errorf("NYI") - } -} - -func (StdEng) denseRepeat(t DenseTensor, axis int, repeats []int) (retVal DenseTensor, err error) { - var newShape Shape - var size int - if newShape, repeats, size, err = t.Shape().Repeat(axis, repeats...); err != nil { - return nil, errors.Wrap(err, "Unable to get repeated shape") - } - - if axis == AllAxes { - axis = 0 - } - - d := recycledDense(t.Dtype(), newShape) - - var outers int - if t.IsScalar() { - outers = 1 - } else { - outers = ProdInts(t.Shape()[0:axis]) - if outers == 0 { - outers = 1 - } - } - - var stride, newStride int - if newShape.IsVector() || t.IsVector() { - stride = 1 // special case because CalcStrides() will return []int{1} as the strides for a vector - } else { - stride = t.ostrides()[axis] - } - - if newShape.IsVector() { - newStride = 1 - } else { - newStride = d.ostrides()[axis] - } - - var destStart, srcStart int - for i := 0; i < outers; i++ { - for j := 0; j < size; j++ { - var tmp int - tmp = repeats[j] - - for k := 0; k < tmp; k++ { - if srcStart >= t.len() || destStart+stride > d.len() { - break - } - copyDenseSliced(d, destStart, d.len(), t, srcStart, t.len()) - destStart += newStride - } - srcStart += stride - } - } - return d, nil -} - -func (e StdEng) Concat(t Tensor, axis int, 
others ...Tensor) (retVal Tensor, err error) { - switch tt := t.(type) { - case DenseTensor: - var denses []DenseTensor - if denses, err = tensorsToDenseTensors(others); err != nil { - return nil, errors.Wrap(err, "Concat failed") - } - return e.denseConcat(tt, axis, denses) - default: - return nil, errors.Errorf("NYI") - } -} - -func (e StdEng) denseConcat(a DenseTensor, axis int, Ts []DenseTensor) (DenseTensor, error) { - ss := make([]Shape, len(Ts)) - var err error - var isMasked bool - for i, T := range Ts { - ss[i] = T.Shape() - if mt, ok := T.(MaskedTensor); ok { - isMasked = isMasked || mt.IsMasked() - } - } - - var newShape Shape - if newShape, err = a.Shape().Concat(axis, ss...); err != nil { - return nil, errors.Wrap(err, "Unable to find new shape that results from concatenation") - } - - retVal := recycledDense(a.Dtype(), newShape) - if isMasked { - retVal.makeMask() - } - - all := make([]DenseTensor, len(Ts)+1) - all[0] = a - copy(all[1:], Ts) - - // special case - var start, end int - - for _, T := range all { - end += T.Shape()[axis] - slices := make([]Slice, axis+1) - slices[axis] = makeRS(start, end) - - var v *Dense - if v, err = sliceDense(retVal, slices...); err != nil { - return nil, errors.Wrap(err, "Unable to slice DenseTensor while performing denseConcat") - } - - if v.IsVector() && T.IsMatrix() && axis == 0 { - v.reshape(v.shape[0], 1) - } - - if err = assignArray(v, T); err != nil { - return nil, errors.Wrap(err, "Unable to assignArray in denseConcat") - } - start = end - } - - return retVal, nil -} +package tensor + +import ( + "context" + + "github.com/pkg/errors" + "gorgonia.org/dtype" + "gorgonia.org/tensor/internal/storage" + + "gorgonia.org/shapes" +) + +var ( + _ Diager = StdEng{} +) + +type fastcopier interface { + fastCopyDenseRepeat(t DenseTensor, d *Dense, outers, size, stride, newStride int, repeats []int) error +} + +// Repeat ... +func (e StdEng) Repeat(ctx context.Context, t Tensor, axis int, repeats ...int) (Tensor, error) { + if err := handleCtx(ctx); err != nil { + return nil, err + } + + switch tt := t.(type) { + case DenseTensor: + newShape, newRepeats, newAxis, size, err := e.denseRepeatCheck(t, axis, repeats) + if err != nil { + return nil, err + } + rr := recycledDense(t.Dtype(), newShape, WithEngine(StdEng{})) + return e.denseRepeat(tt, rr, newShape, newAxis, size, newRepeats) + default: + return nil, nyierr(typeNYI, t) + } +} + +// RepeatReuse is like Repeat, but with a provided reuse Tensor. The reuseTensor must be of the same type as the input t. +func (e StdEng) RepeatReuse(ctx context.Context, t Tensor, reuse Tensor, axis int, repeats ...int) (Tensor, error) { + if err := handleCtx(ctx); err != nil { + return nil, err + } + + switch tt := t.(type) { + case DenseTensor: + newShape, newRepeats, newAxis, size, err := e.denseRepeatCheck(t, axis, repeats) + if err != nil { + return nil, err + } + + rr, ok := reuse.(DenseTensor) + if !ok { + return nil, errors.Errorf("t is a DenseTensor but reuse is of %T", reuse) + } + if !reuse.Shape().Eq(newShape) { + return nil, errors.Errorf("Reuse shape is %v. 
Expected shape is %v", reuse.Shape(), newShape) + } + return e.denseRepeat(tt, rr, newShape, newAxis, size, newRepeats) + default: + return nil, nyierr(typeNYI, t) + } +} + +func (StdEng) denseRepeatCheck(t Tensor, axis int, repeats []int) (newShape Shape, newRepeats []int, newAxis, size int, err error) { + var newShapelike shapes.Shapelike + if newShapelike, newRepeats, size, err = t.Shape().Repeat(shapes.Axis(axis), repeats...); err != nil { + return nil, nil, -1, -1, errors.Wrap(err, "Unable to get repeated shape") + } + newShape = newShapelike.(Shape) + newAxis = axis + if axis == AllAxes { + newAxis = 0 + } + + return +} + +func (StdEng) denseRepeat(t, reuse DenseTensor, newShape Shape, axis, size int, repeats []int) (retVal DenseTensor, err error) { + d, err := assertDense(reuse) + if err != nil { + return nil, errors.Wrapf(err, "Repeat reuse is not a *Dense") + } + var outers int + if t.IsScalar() { + outers = 1 + } else { + outers = ProdInts(t.Shape()[0:axis]) + } + + var stride, newStride int + if newShape.IsVector() || t.IsVector() { + stride = 1 // special case because CalcStrides() will return []int{1} as the strides for a vector + } else { + stride = t.ostrides()[axis] + } + + if newShape.IsVector() { + newStride = 1 + } else { + newStride = d.ostrides()[axis] + } + + var destStart, srcStart int + // fastCopy is not bypassing the copyDenseSliced method to populate the output tensor + var fastCopy bool + var fce fastcopier + // we need an engine for fastCopying... + e := t.Engine() + // e can never be nil. Error would have occurred elsewhere + var ok bool + if fce, ok = e.(fastcopier); ok { + fastCopy = true + } + + // In this case, let's not implement the fast copy to keep the code readable + if ms, ok := t.(MaskedTensor); ok && ms.IsMasked() { + fastCopy = false + } + + // if d is not a fastcopier, then we also cannot use fast copy + if _, ok := d.Engine().(fastcopier); !ok { + fastCopy = false + } + + if fastCopy { + if err := fce.fastCopyDenseRepeat(t, d, outers, size, stride, newStride, repeats); err != nil { + return nil, err + } + return d, nil + } + + for i := 0; i < outers; i++ { + for j := 0; j < size; j++ { + var tmp int + tmp = repeats[j] + + for k := 0; k < tmp; k++ { + if srcStart >= t.len() || destStart+stride > d.len() { + break + } + copyDenseSliced(d, destStart, d.len(), t, srcStart, t.len()) + destStart += newStride + } + srcStart += stride + } + } + return d, nil +} + +func (e StdEng) fastCopyDenseRepeat(src DenseTensor, dest *Dense, outers, size, stride, newStride int, repeats []int) error { + sarr := src.arr() + darr := dest.arr() + + var destStart, srcStart int + for i := 0; i < outers; i++ { + // faster shortcut for common case. + // + // Consider a case where: + // a := ⎡ 1 ⎤ + // ⎢ 2 ⎥ + // ⎢ 3 ⎥ + // ⎣ 4 ⎦ + // a has a shape of (4, 1). it is a *Dense. + // + // Now assume we want to repeat it on axis 1, 3 times. We want to repeat it into `b`, + // which is already allocated and zeroed, as shown below + // + // b := ⎡ 0 0 0 ⎤ + // ⎢ 0 0 0 ⎥ + // ⎢ 0 0 0 ⎥ + // ⎣ 0 0 0 ⎦ + // + // Now, both `a` and `b` have a stride of 1. + // + // The desired result is: + // b := ⎡ 1 1 1 ⎤ + // ⎢ 2 2 2 ⎥ + // ⎢ 3 3 3 ⎥ + // ⎣ 4 4 4 ⎦ + /// + // Observe that this is simply broadcasting (copying) a[0] (a scalar value) to the row b[0], and so on and so forth. + // This can be done without knowing the full type - we simply copy the bytes over. 
+ if stride == 1 && newStride == 1 { + for sz := 0; sz < size; sz++ { + tmp := repeats[sz] + + // first we get the bounds of the src and the dest + // the srcStart and destStart are the indices assuming a flat array of []T + // we need to get the byte slice equivalent. + bSrcStart := srcStart * int(sarr.t.Size()) + bSrcEnd := (srcStart + stride) * int(sarr.t.Size()) + bDestStart := destStart * int(darr.t.Size()) + bDestEnd := (destStart + tmp) * int(darr.t.Size()) + + // then we get the data as a slice of raw bytes + sBS := sarr.Header.Raw + dBS := darr.Header.Raw + + // recall that len(src) < len(dest) + // it's easier to understand if we define the ranges. + // Less prone to errors. + sRange := sBS[bSrcStart:bSrcEnd] + dRange := dBS[bDestStart:bDestEnd] + + // finally we copy things. + for i := 0; i < len(dRange); i += len(sRange) { + copy(dRange[i:], sRange) + } + srcStart += stride + destStart += tmp + } + + // we can straightaway broadcast + continue + } + + for j := 0; j < size; j++ { + var tmp int + tmp = repeats[j] + var tSlice array + + tSlice = sarr.slice(srcStart, src.len()) + + for k := 0; k < tmp; k++ { + if srcStart >= src.len() || destStart+stride > dest.len() { + break + } + + dSlice := darr.slice(destStart, destStart+newStride) + + // THIS IS AN OPTIMIZATION. REVISIT WHEN NEEDED. + storage.Copy(dSlice.t.Type, &dSlice.Header, &tSlice.Header) + + destStart += newStride + } + srcStart += stride + } + } + return nil +} + +// Concat tensors +func (e StdEng) Concat(ctx context.Context, t Tensor, axis int, others ...Tensor) (retVal Tensor, err error) { + if err := handleCtx(ctx); err != nil { + return nil, err + } + + switch tt := t.(type) { + case DenseTensor: + var denses []DenseTensor + if denses, err = tensorsToDenseTensors(others); err != nil { + return nil, errors.Wrap(err, "Concat failed") + } + return e.denseConcat(tt, axis, denses) + default: + return nil, nyierr(typeNYI, t) + } +} + +func (e StdEng) denseConcat(a DenseTensor, axis int, Ts []DenseTensor) (DenseTensor, error) { + ss := make([]Shape, len(Ts)) + var err error + var isMasked bool + for i, T := range Ts { + ss[i] = T.Shape() + if mt, ok := T.(MaskedTensor); ok { + isMasked = isMasked || mt.IsMasked() + } + } + + var newShapelike shapes.Shapelike + if newShapelike, err = a.Shape().Concat(shapes.Axis(axis), shapes.ShapesToShapelikes(ss)...); err != nil { + return nil, errors.Wrap(err, "Unable to find new shape that results from concatenation") + } + newShape := newShapelike.(Shape) + + retVal := recycledDense(a.Dtype(), newShape, WithEngine(e)) + if isMasked { + retVal.makeMask() + } + + all := make([]DenseTensor, len(Ts)+1) + all[0] = a + copy(all[1:], Ts) + + // TODO: OPIMIZATION + // When (axis == 0 && a is row major and all others is row major) || (axis == last axis of A && all tensors are colmajor) + // just flat copy + // + + // isOuter is true when the axis is the outermost axis + // isInner is true when the axis is the inner most axis + isOuter := axis == 0 + isInner := axis == (a.Shape().Dims() - 1) + + // special case + var start, end int + for _, T := range all { + end += T.Shape()[axis] + slices := make([]Slice, axis+1) + slices[axis] = makeRS(start, end) + + var v *Dense + if v, err = sliceDense(retVal, slices...); err != nil { + return nil, errors.Wrap(err, "Unable to slice DenseTensor while performing denseConcat") + } + + // keep dims after slicing + switch { + case v.IsVector() && T.IsMatrix() && axis == 0: + v.reshape(v.shape[0], 1) + case T.IsRowVec() && axis == 0: + T.reshape(T.Shape()[1]) + 
case v.Shape().IsScalarEquiv() && T.Shape().IsScalarEquiv(): + copyArray(v.arrPtr(), T.arrPtr()) + if mt, ok := T.(MaskedTensor); ok { + copy(v.mask, mt.Mask()) + } + start = end + continue + default: + diff := retVal.Shape().Dims() - v.Shape().Dims() + if diff > 0 && isOuter { + newShape := make(Shape, v.Shape().Dims()+diff) + for i := 0; i < diff; i++ { + newShape[i] = 1 + } + copy(newShape[diff:], v.Shape()) + v.reshape(newShape...) + } else if diff > 0 && isInner { + newShape := v.Shape().Clone() + newStrides := v.strides + for i := 0; i < diff; i++ { + newShape = append(newShape, 1) + newStrides = append(newStrides, 1) + } + v.shape = newShape + v.strides = newStrides + } else if T.Shape()[axis] == 1 { + if err := v.unsqueeze(axis); err != nil { + return nil, errors.Wrapf(err, "Unable to keep dims after slicing a shape %v on axis %d where the size is 1", T.Shape(), axis) + } + } + } + + var vmask, Tmask []bool + vmask = v.mask + v.mask = nil + if mt, ok := T.(MaskedTensor); ok && mt.IsMasked() { + Tmask = mt.Mask() + mt.SetMask(nil) + + } + + if err = assignArray(v, T); err != nil { + return nil, errors.Wrap(err, "Unable to assignArray in denseConcat") + } + // if it's a masked tensor, we copy the mask as well + if Tmask != nil { + if vmask != nil { + if cap(vmask) < len(Tmask) { + vmask2 := make([]bool, len(Tmask)) + copy(vmask2, vmask) + vmask = vmask2 + } + copy(vmask, Tmask) + v.SetMask(vmask) + } + // mt.SetMask(Tmask) + } + + start = end + } + + return retVal, nil +} + +// Diag ... +func (e StdEng) Diag(ctx context.Context, t Tensor) (retVal Tensor, err error) { + if err := handleCtx(ctx); err != nil { + return nil, err + } + + a, ok := t.(DenseTensor) + if !ok { + return nil, errors.Errorf("StdEng only works with DenseTensor for Diagonal()") + } + + if a.Dims() != 2 { + err = errors.Errorf(dimMismatch, 2, a.Dims()) + return + } + + if err = dtype.TypeClassCheck(a.Dtype(), dtype.Number); err != nil { + return nil, errors.Wrap(err, "Diagonal") + } + + rstride := a.Strides()[0] + cstride := a.Strides()[1] + + r := a.Shape()[0] + c := a.Shape()[1] + + m := MinInt(r, c) + stride := rstride + cstride + + b := a.Clone().(DenseTensor) + b.Zero() + + switch a.rtype().Size() { + case 1: + bdata := b.hdr().Uint8s() + adata := a.hdr().Uint8s() + for i := 0; i < m; i++ { + bdata[i] = adata[i*stride] + } + case 2: + bdata := b.hdr().Uint16s() + adata := a.hdr().Uint16s() + for i := 0; i < m; i++ { + bdata[i] = adata[i*stride] + } + case 4: + bdata := b.hdr().Uint32s() + adata := a.hdr().Uint32s() + for i := 0; i < m; i++ { + bdata[i] = adata[i*stride] + } + case 8: + bdata := b.hdr().Uint64s() + adata := a.hdr().Uint64s() + for i := 0; i < m; i++ { + bdata[i] = adata[i*stride] + } + default: + return nil, nyierr(typeNYI, "Arbitrary-sized .Diag()", t) + } + return b, nil +} diff --git a/defaultengine_matop_stack.go b/defaultengine_matop_stack.go index 1a43a7e..33c148d 100644 --- a/defaultengine_matop_stack.go +++ b/defaultengine_matop_stack.go @@ -1,13 +1,23 @@ package tensor import ( + "context" + "github.com/pkg/errors" - "gorgonia.org/tensor/internal/storage" ) // This file contains code for the execution engine to stack tensors -func (e StdEng) StackDense(t DenseTensor, axis int, others ...DenseTensor) (retVal DenseTensor, err error) { +var ( + // _ Stacker = StdEng{} + _ DenseStacker = StdEng{} +) + +func (e StdEng) StackDense(ctx context.Context, t DenseTensor, axis int, others ...DenseTensor) (retVal DenseTensor, err error) { + if err := handleCtx(ctx); err != nil { + return nil, err + 
} + opdims := t.Dims() if axis >= opdims+1 { err = errors.Errorf(dimMismatch, opdims+1, axis) @@ -28,15 +38,13 @@ func (e StdEng) StackDense(t DenseTensor, axis int, others ...DenseTensor) (retV info := t.Info() var newStrides []int - if info.o.isColMajor() { - newStrides = newShape.calcStridesColMajor() + if info.o.IsColMajor() { + newStrides = CalcStridesColMajor(newShape) } else { - newStrides = newShape.calcStrides() + newStrides = CalcStrides(newShape) } - ap := NewAP(newShape, newStrides) - ap.o = info.o - ap.Δ = info.Δ + ap := MakeAP(newShape, newStrides, info.o, info.Δ) allNoMat := !t.RequiresIterator() for _, ot := range others { @@ -46,8 +54,7 @@ func (e StdEng) StackDense(t DenseTensor, axis int, others ...DenseTensor) (retV } retVal = recycledDense(t.Dtype(), ap.Shape(), WithEngine(e)) - ReturnAP(retVal.Info()) - retVal.setAP(ap) + retVal.setAP(&ap) // the "viewStack" method is the more generalized method // and will work for all Tensors, regardless of whether it's a view @@ -369,7 +376,7 @@ func (e StdEng) doViewStack8(t, retVal DenseTensor, axisStride, batches int, it func (e StdEng) doViewStackArbitrary(t, retVal DenseTensor, axisStride, batches int, it Iterator, others []DenseTensor, its []Iterator) (err error) { dt := t.Dtype() - data := storage.AsByteSlice(retVal.hdr(), dt.Type)[:0] + data := retVal.hdr().Raw[:0] // truncate to 0 size := int(dt.Size()) var mask []bool var retIsMasked bool @@ -388,8 +395,7 @@ func (e StdEng) doViewStackArbitrary(t, retVal DenseTensor, axisStride, batches tmask = mt.Mask() isMasked = mt.IsMasked() } - dt := t.Dtype() - bs := storage.AsByteSlice(t.hdr(), dt.Type) + bs := t.hdr().Raw for last = 0; last < axisStride; last++ { id, err := it.Next() diff --git a/defaultengine_matop_transpose.go b/defaultengine_matop_transpose.go index d2beba6..7ca63cb 100644 --- a/defaultengine_matop_transpose.go +++ b/defaultengine_matop_transpose.go @@ -1,259 +1,161 @@ -package tensor - -import ( - "github.com/pkg/errors" - "gorgonia.org/tensor/internal/storage" -) - -func (e StdEng) Transpose(a Tensor, expStrides []int) error { - if !a.IsNativelyAccessible() { - return errors.Errorf("Cannot Transpose() on non-natively accessible tensor") - } - if dt, ok := a.(DenseTensor); ok { - e.denseTranspose(dt, expStrides) - return nil - } - return errors.Errorf("Tranpose for tensor of %T not supported", a) -} - -func (e StdEng) denseTranspose(a DenseTensor, expStrides []int) { - if a.rtype() == String.Type { - e.denseTransposeString(a, expStrides) - return - } - - switch a.rtype().Size() { - case 1: - e.denseTranspose1(a, expStrides) - case 2: - e.denseTranspose2(a, expStrides) - case 4: - e.denseTranspose4(a, expStrides) - case 8: - e.denseTranspose8(a, expStrides) - default: - e.denseTransposeArbitrary(a, expStrides) - } -} - -func (e StdEng) denseTranspose1(a DenseTensor, expStrides []int) { - axes := a.transposeAxes() - size := a.len() - - // first we'll create a bit-map to track which elements have been moved to their correct places - track := NewBitMap(size) - track.Set(0) - track.Set(size - 1) // first and last element of a transposedon't change - - var saved, tmp byte - var i int - - data := a.hdr().Uint8s() - for i = 1; ; { - dest := a.transposeIndex(i, axes, expStrides) - - if track.IsSet(i) && track.IsSet(dest) { - data[i] = saved - saved = 0 - for i < size && track.IsSet(i) { - i++ - } - if i >= size { - break - } - continue - } - track.Set(i) - tmp = data[i] - data[i] = saved - saved = tmp - - i = dest - } -} - -func (e StdEng) denseTranspose2(a 
DenseTensor, expStrides []int) { - axes := a.transposeAxes() - size := a.len() - - // first we'll create a bit-map to track which elements have been moved to their correct places - track := NewBitMap(size) - track.Set(0) - track.Set(size - 1) // first and last element of a transposedon't change - - var saved, tmp uint16 - var i int - - data := a.hdr().Uint16s() - for i = 1; ; { - dest := a.transposeIndex(i, axes, expStrides) - - if track.IsSet(i) && track.IsSet(dest) { - data[i] = saved - saved = 0 - for i < size && track.IsSet(i) { - i++ - } - if i >= size { - break - } - continue - } - track.Set(i) - tmp = data[i] - data[i] = saved - saved = tmp - - i = dest - } -} - -func (e StdEng) denseTranspose4(a DenseTensor, expStrides []int) { - axes := a.transposeAxes() - size := a.len() - - // first we'll create a bit-map to track which elements have been moved to their correct places - track := NewBitMap(size) - track.Set(0) - track.Set(size - 1) // first and last element of a transposedon't change - - var saved, tmp uint32 - var i int - - data := a.hdr().Uint32s() - for i = 1; ; { - dest := a.transposeIndex(i, axes, expStrides) - - if track.IsSet(i) && track.IsSet(dest) { - data[i] = saved - saved = 0 - for i < size && track.IsSet(i) { - i++ - } - if i >= size { - break - } - continue - } - track.Set(i) - tmp = data[i] - data[i] = saved - saved = tmp - - i = dest - } -} - -func (e StdEng) denseTranspose8(a DenseTensor, expStrides []int) { - axes := a.transposeAxes() - size := a.len() - - // first we'll create a bit-map to track which elements have been moved to their correct places - track := NewBitMap(size) - track.Set(0) - track.Set(size - 1) // first and last element of a transposedon't change - - var saved, tmp uint64 - var i int - - data := a.hdr().Uint64s() - for i = 1; ; { - dest := a.transposeIndex(i, axes, expStrides) - - if track.IsSet(i) && track.IsSet(dest) { - data[i] = saved - saved = 0 - for i < size && track.IsSet(i) { - i++ - } - if i >= size { - break - } - continue - } - track.Set(i) - tmp = data[i] - data[i] = saved - saved = tmp - - i = dest - } -} - -func (e StdEng) denseTransposeString(a DenseTensor, expStrides []int) { - axes := a.transposeAxes() - size := a.len() - - // first we'll create a bit-map to track which elements have been moved to their correct places - track := NewBitMap(size) - track.Set(0) - track.Set(size - 1) // first and last element of a transposedon't change - - var saved, tmp string - var i int - - data := a.hdr().Strings() - for i = 1; ; { - dest := a.transposeIndex(i, axes, expStrides) - - if track.IsSet(i) && track.IsSet(dest) { - data[i] = saved - saved = "" - for i < size && track.IsSet(i) { - i++ - } - if i >= size { - break - } - continue - } - track.Set(i) - tmp = data[i] - data[i] = saved - saved = tmp - - i = dest - } -} - -func (e StdEng) denseTransposeArbitrary(a DenseTensor, expStrides []int) { - axes := a.transposeAxes() - size := a.len() - rtype := a.rtype() - typeSize := int(rtype.Size()) - - // first we'll create a bit-map to track which elements have been moved to their correct places - track := NewBitMap(size) - track.Set(0) - track.Set(size - 1) // first and last element of a transposedon't change - - saved := make([]byte, typeSize, typeSize) - tmp := make([]byte, typeSize, typeSize) - var i int - - data := storage.AsByteSlice(a.hdr(), rtype) - for i = 1; ; { - dest := a.transposeIndex(i, axes, expStrides) - start := typeSize * i - - if track.IsSet(i) && track.IsSet(dest) { - copy(data[start:start+typeSize], saved) - for i := 
range saved { - saved[i] = 0 - } - for i < size && track.IsSet(i) { - i++ - } - if i >= size { - break - } - continue - } - track.Set(i) - copy(tmp, data[start:start+typeSize]) - copy(data[start:start+typeSize], saved) - saved = tmp - - i = dest - } -} +// +build !inplacetranspose + +package tensor + +import ( + "context" + + "github.com/pkg/errors" +) + +func (e StdEng) Transpose(ctx context.Context, a Tensor, expStrides []int) error { + if err := handleCtx(ctx); err != nil { + return err + } + + if !a.IsNativelyAccessible() { + return errors.Errorf("Cannot Transpose() on non-natively accessible tensor") + } + if dt, ok := a.(DenseTensor); ok { + e.denseTranspose(dt, expStrides) + return nil + } + return errors.Errorf("Tranpose for tensor of %T not supported", a) +} + +func (e StdEng) denseTranspose(a DenseTensor, expStrides []int) { + if a.rtype() == String.Type { + e.denseTransposeString(a, expStrides) + return + } + + e.transposeMask(a) + + switch a.rtype().Size() { + case 1: + e.denseTranspose1(a, expStrides) + case 2: + e.denseTranspose2(a, expStrides) + case 4: + e.denseTranspose4(a, expStrides) + case 8: + e.denseTranspose8(a, expStrides) + default: + e.denseTransposeArbitrary(a, expStrides) + } +} + +func (e StdEng) transposeMask(a DenseTensor) { + if !a.(*Dense).IsMasked() { + return + } + + orig := a.(*Dense).Mask() + tmp := make([]bool, len(orig)) + + it := newFlatIterator(a.Info()) + var j int + for i, err := it.Next(); err == nil; i, err = it.Next() { + tmp[j] = orig[i] + j++ + } + copy(orig, tmp) +} + +func (e StdEng) denseTranspose1(a DenseTensor, expStrides []int) { + var tmpArr array + e.makeArray(&tmpArr, a.Dtype(), a.Size()) + u8s := tmpArr.Uint8s() + + orig := a.hdr().Uint8s() + it := newFlatIterator(a.Info()) + var j int + for i, err := it.Next(); err == nil; i, err = it.Next() { + u8s[j] = orig[i] + j++ + } + copy(orig, u8s) +} + +func (e StdEng) denseTranspose2(a DenseTensor, expStrides []int) { + var tmpArr array + e.makeArray(&tmpArr, a.Dtype(), a.Size()) + u16s := tmpArr.Uint16s() + + orig := a.hdr().Uint16s() + it := newFlatIterator(a.Info()) + var j int + for i, err := it.Next(); err == nil; i, err = it.Next() { + u16s[j] = orig[i] + j++ + } + copy(orig, u16s) +} + +func (e StdEng) denseTranspose4(a DenseTensor, expStrides []int) { + var tmpArr array + e.makeArray(&tmpArr, a.Dtype(), a.Size()) + u32s := tmpArr.Uint32s() + + orig := a.hdr().Uint32s() + it := newFlatIterator(a.Info()) + var j int + for i, err := it.Next(); err == nil; i, err = it.Next() { + u32s[j] = orig[i] + j++ + } + copy(orig, u32s) +} + +func (e StdEng) denseTranspose8(a DenseTensor, expStrides []int) { + var tmpArr array + e.makeArray(&tmpArr, a.Dtype(), a.Size()) + u64s := tmpArr.Uint64s() + + orig := a.hdr().Uint64s() + it := newFlatIterator(a.Info()) + var j int + for i, err := it.Next(); err == nil; i, err = it.Next() { + u64s[j] = orig[i] + j++ + } + copy(orig, u64s) +} + +func (e StdEng) denseTransposeString(a DenseTensor, expStrides []int) { + var tmpArr array + e.makeArray(&tmpArr, a.Dtype(), a.Size()) + strs := tmpArr.Strings() + + orig := a.hdr().Strings() + it := newFlatIterator(a.Info()) + var j int + for i, err := it.Next(); err == nil; i, err = it.Next() { + strs[j] = orig[i] + j++ + } + copy(orig, strs) +} + +func (e StdEng) denseTransposeArbitrary(a DenseTensor, expStrides []int) { + rtype := a.rtype() + typeSize := int(rtype.Size()) + var tmpArr array + e.makeArray(&tmpArr, a.Dtype(), a.Size()) + // arbs := storage.AsByteSlice(tmpArr.hdr(), rtype) + arbs := 
tmpArr.byteSlice() + + orig := a.hdr().Raw + it := newFlatIterator(a.Info()) + var j int + for i, err := it.Next(); err == nil; i, err = it.Next() { + srcStart := i * typeSize + srcEnd := srcStart + typeSize + dstStart := j * typeSize + dstEnd := dstStart + typeSize + + copy(arbs[dstStart:dstEnd], orig[srcStart:srcEnd]) + j++ + } + copy(orig, arbs) +} diff --git a/defaultengine_matop_transpose_inplace.go b/defaultengine_matop_transpose_inplace.go new file mode 100644 index 0000000..8d1d5f3 --- /dev/null +++ b/defaultengine_matop_transpose_inplace.go @@ -0,0 +1,326 @@ +// +build inplacetranspose + +package tensor + +import ( + "context" + + "github.com/pkg/errors" +) + +func (e StdEng) Transpose(ctx context.Context, a Tensor, expStrides []int) error { + if err := handleCtx(ctx); err != nil { + return err + } + + if !a.IsNativelyAccessible() { + return errors.Errorf("Cannot Transpose() on non-natively accessible tensor") + } + if dt, ok := a.(DenseTensor); ok { + e.denseTranspose(dt, expStrides) + return nil + } + return errors.Errorf("Tranpose for tensor of %T not supported", a) +} + +func (e StdEng) denseTranspose(a DenseTensor, expStrides []int) { + if a.rtype() == String.Type { + e.denseTransposeString(a, expStrides) + return + } + + e.transposeMask(a) + + switch a.rtype().Size() { + case 1: + e.denseTranspose1(a, expStrides) + case 2: + e.denseTranspose2(a, expStrides) + case 4: + e.denseTranspose4(a, expStrides) + case 8: + e.denseTranspose8(a, expStrides) + default: + e.denseTransposeArbitrary(a, expStrides) + } +} + +func (e StdEng) transposeMask(a DenseTensor) { + if !a.(*Dense).IsMasked() { + return + } + + shape := a.Shape() + if len(shape) != 2 { + // TODO(poopoothegorilla): currently only two dimensions are implemented + return + } + n, m := shape[0], shape[1] + mask := a.(*Dense).Mask() + size := len(mask) + + track := NewBitMap(size) + track.Set(0) + track.Set(size - 1) + + for i := 0; i < size; i++ { + srci := i + if track.IsSet(srci) { + continue + } + srcv := mask[srci] + for { + oc := srci % n + or := (srci - oc) / n + desti := oc*m + or + + if track.IsSet(desti) { + break + } + track.Set(desti) + destv := mask[desti] + mask[desti] = srcv + srci = desti + srcv = destv + } + } +} + +func (e StdEng) denseTranspose1(a DenseTensor, expStrides []int) { + axes := a.transposeAxes() + size := a.len() + + // first we'll create a bit-map to track which elements have been moved to their correct places + track := NewBitMap(size) + track.Set(0) + track.Set(size - 1) // first and last element of a transposedon't change + + var saved, tmp byte + var i int + + data := a.hdr().Uint8s() + if len(data) < 4 { + return + } + for i = 1; ; { + dest := a.transposeIndex(i, axes, expStrides) + + if track.IsSet(i) && track.IsSet(dest) { + data[i] = saved + saved = 0 + for i < size && track.IsSet(i) { + i++ + } + if i >= size { + break + } + continue + } + track.Set(i) + tmp = data[i] + data[i] = saved + saved = tmp + + i = dest + } +} + +func (e StdEng) denseTranspose2(a DenseTensor, expStrides []int) { + axes := a.transposeAxes() + size := a.len() + + // first we'll create a bit-map to track which elements have been moved to their correct places + track := NewBitMap(size) + track.Set(0) + track.Set(size - 1) // first and last element of a transposedon't change + + var saved, tmp uint16 + var i int + + data := a.hdr().Uint16s() + if len(data) < 4 { + return + } + for i = 1; ; { + dest := a.transposeIndex(i, axes, expStrides) + + if track.IsSet(i) && track.IsSet(dest) { + data[i] = saved + saved = 0 
+ for i < size && track.IsSet(i) { + i++ + } + if i >= size { + break + } + continue + } + track.Set(i) + tmp = data[i] + data[i] = saved + saved = tmp + + i = dest + } +} + +func (e StdEng) denseTranspose4(a DenseTensor, expStrides []int) { + axes := a.transposeAxes() + size := a.len() + + // first we'll create a bit-map to track which elements have been moved to their correct places + track := NewBitMap(size) + track.Set(0) + track.Set(size - 1) // first and last element of a transposedon't change + + var saved, tmp uint32 + var i int + + data := a.hdr().Uint32s() + if len(data) < 4 { + return + } + for i = 1; ; { + dest := a.transposeIndex(i, axes, expStrides) + + if track.IsSet(i) && track.IsSet(dest) { + data[i] = saved + saved = 0 + for i < size && track.IsSet(i) { + i++ + } + if i >= size { + break + } + continue + } + track.Set(i) + tmp = data[i] + data[i] = saved + saved = tmp + + i = dest + } +} + +func (e StdEng) denseTranspose8(a DenseTensor, expStrides []int) { + axes := a.transposeAxes() + size := a.len() + + // first we'll create a bit-map to track which elements have been moved to their correct places + track := NewBitMap(size) + track.Set(0) + track.Set(size - 1) // first and last element of a transposedon't change + + var saved, tmp uint64 + var i int + + data := a.hdr().Uint64s() + if len(data) < 4 { + return + } + for i = 1; ; { + dest := a.transposeIndex(i, axes, expStrides) + if track.IsSet(i) && track.IsSet(dest) { + data[i] = saved + saved = 0 + for i < size && track.IsSet(i) { + i++ + } + if i >= size { + break + } + continue + } + track.Set(i) + // log.Printf("i: %d start %d, end %d | tmp %v saved %v", i, start, end, tmp, saved) + tmp = data[i] + data[i] = saved + saved = tmp + + i = dest + } +} + +func (e StdEng) denseTransposeString(a DenseTensor, expStrides []int) { + axes := a.transposeAxes() + size := a.len() + + // first we'll create a bit-map to track which elements have been moved to their correct places + track := NewBitMap(size) + track.Set(0) + track.Set(size - 1) // first and last element of a transposedon't change + + var saved, tmp string + var i int + + data := a.hdr().Strings() + if len(data) < 4 { + return + } + for i = 1; ; { + dest := a.transposeIndex(i, axes, expStrides) + + if track.IsSet(i) && track.IsSet(dest) { + data[i] = saved + saved = "" + for i < size && track.IsSet(i) { + i++ + } + if i >= size { + break + } + continue + } + track.Set(i) + tmp = data[i] + data[i] = saved + saved = tmp + + i = dest + } +} + +func (e StdEng) denseTransposeArbitrary(a DenseTensor, expStrides []int) { + axes := a.transposeAxes() + size := a.len() + rtype := a.rtype() + typeSize := int(rtype.Size()) + + // first we'll create a bit-map to track which elements have been moved to their correct places + track := NewBitMap(size) + track.Set(0) + track.Set(size - 1) // first and last element of a transposedon't change + + saved := make([]byte, typeSize, typeSize) + tmp := make([]byte, typeSize, typeSize) + var i int + data := a.arr().Raw + if len(data) < 4*typeSize { + return + } + for i = 1; ; { + dest := a.transposeIndex(i, axes, expStrides) + start := typeSize * i + end := start + typeSize + + if track.IsSet(i) && track.IsSet(dest) { + copy(data[start:end], saved) + for i := range saved { + saved[i] = 0 + } + for i < size && track.IsSet(i) { + i++ + } + if i >= size { + break + } + continue + } + track.Set(i) + copy(tmp, data[start:end]) + copy(data[start:end], saved) + copy(saved, tmp) + i = dest + } +} diff --git a/defaultengine_minmax.go 
b/defaultengine_minmax.go new file mode 100644 index 0000000..a16cbf0 --- /dev/null +++ b/defaultengine_minmax.go @@ -0,0 +1,370 @@ +package tensor + +import ( + "context" + + "github.com/pkg/errors" + "gorgonia.org/dtype" + "gorgonia.org/tensor/internal/storage" +) + +// Code generated by genlib2. DO NOT EDIT. + +var ( + _ MinBetweener = StdEng{} + _ MaxBetweener = StdEng{} +) + +func (e StdEng) MinBetween(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { + if err = binaryCheck(a, b, dtype.Ord); err != nil { + + return nil, errors.Wrapf(err, "MinBetween failed") + } + + var reuse DenseTensor + var safe bool + var ctx context.Context + if ctx, reuse, safe, _, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + return nil, errors.Wrap(err, "Unable to handle funcOpts") + } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } + typ := a.Dtype().Type + var dataA, dataB, dataReuse *storage.Header + var ait, bit, iit Iterator + var useIter, swap bool + if dataA, dataB, dataReuse, ait, bit, iit, useIter, swap, err = prepDataVV(a, b, reuse); err != nil { + return nil, errors.Wrapf(err, "StdEng.MinBetween") + } + // check to see if anything needs to be created + if reuse == nil { + if swap { + reuse = NewDense(b.Dtype(), b.Shape().Clone(), WithEngine(e)) + } else { + reuse = NewDense(a.Dtype(), a.Shape().Clone(), WithEngine(e)) + } + dataReuse = reuse.hdr() + if useIter { + iit = IteratorFromDense(reuse) + } + } + + if useIter { + switch { + case !safe && reuse == nil: + err = e.E.MinBetweenIter(typ, dataA, dataB, ait, bit) + retVal = a + case safe && reuse != nil: + storage.CopyIter(typ, dataReuse, dataA, iit, ait) + ait.Reset() + iit.Reset() + err = e.E.MinBetweenIter(typ, dataReuse, dataB, iit, bit) + retVal = reuse + default: // safe && bool + panic("Unreachable") + } + return + } + + // standard + switch { + case !safe && reuse == nil: + err = e.E.MinBetween(typ, dataA, dataB) + retVal = a + case safe && reuse != nil: + storage.Copy(typ, dataReuse, dataA) + err = e.E.MinBetween(typ, dataReuse, dataB) + retVal = reuse + default: + panic("Unreachable") + } + return +} + +func (e StdEng) MaxBetween(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { + if err = binaryCheck(a, b, dtype.Ord); err != nil { + + return nil, errors.Wrapf(err, "MaxBetween failed") + } + + var reuse DenseTensor + var safe bool + var ctx context.Context + if ctx, reuse, safe, _, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + return nil, errors.Wrap(err, "Unable to handle funcOpts") + } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } + typ := a.Dtype().Type + var dataA, dataB, dataReuse *storage.Header + var ait, bit, iit Iterator + var useIter, swap bool + if dataA, dataB, dataReuse, ait, bit, iit, useIter, swap, err = prepDataVV(a, b, reuse); err != nil { + return nil, errors.Wrapf(err, "StdEng.MaxBetween") + } + // check to see if anything needs to be created + if reuse == nil { + if swap { + reuse = NewDense(b.Dtype(), b.Shape().Clone(), WithEngine(e)) + } else { + reuse = NewDense(a.Dtype(), a.Shape().Clone(), WithEngine(e)) + } + dataReuse = reuse.hdr() + if useIter { + iit = IteratorFromDense(reuse) + } + } + + if useIter { + switch { + case !safe && reuse == nil: + err = e.E.MaxBetweenIter(typ, dataA, dataB, ait, bit) + retVal = a + case safe && reuse != nil: + storage.CopyIter(typ, dataReuse, dataA, iit, ait) + ait.Reset() + iit.Reset() + err = e.E.MaxBetweenIter(typ, dataReuse, dataB, iit, bit) + retVal = reuse + default: // safe && bool + panic("Unreachable") + } + return + } + + // standard + switch { + case !safe && reuse == nil: + err = e.E.MaxBetween(typ, dataA, dataB) + retVal = a + case safe && reuse != nil: + storage.Copy(typ, dataReuse, dataA) + err = e.E.MaxBetween(typ, dataReuse, dataB) + retVal = reuse + default: + panic("Unreachable") + } + return +} + +func (e StdEng) MinBetweenScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { + if err = unaryCheck(t, dtype.Ord); err != nil { + return nil, errors.Wrapf(err, "MinBetween failed") + } + + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "MinBetween failed") + } + + var reuse DenseTensor + var safe bool + var ctx context.Context + if ctx, reuse, safe, _, _, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { + return nil, errors.Wrap(err, "Unable to handle funcOpts") + } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } + a := t + typ := t.Dtype().Type + var ait, bit, iit Iterator + var dataA, dataB, dataReuse, scalarHeader *storage.Header + var useIter, newAlloc bool + + if leftTensor { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { + return nil, errors.Wrapf(err, opFail, "StdEng.MinBetween") + } + scalarHeader = dataB + } else { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { + return nil, errors.Wrapf(err, opFail, "StdEng.MinBetween") + } + scalarHeader = dataA + } + + // check to see if anything needs to be created + if reuse == nil { + reuse = NewDense(a.Dtype(), a.Shape().Clone(), WithEngine(e)) + dataReuse = reuse.hdr() + if useIter { + iit = IteratorFromDense(reuse) + } + } + + if useIter { + switch { + case !safe && reuse == nil: + err = e.E.MinBetweenIter(typ, dataA, dataB, ait, bit) + retVal = a + case safe && reuse != nil && !leftTensor: + storage.CopyIter(typ, dataReuse, dataB, iit, bit) + bit.Reset() + iit.Reset() + err = e.E.MinBetweenIter(typ, dataA, dataReuse, ait, bit) + retVal = reuse + case safe && reuse != nil && leftTensor: + storage.CopyIter(typ, dataReuse, dataA, iit, ait) + ait.Reset() + iit.Reset() + err = e.E.MinBetweenIter(typ, dataReuse, dataB, iit, bit) + retVal = reuse + default: // safe && bool + panic("Unreachable") + } + if newAlloc { + freeScalar(scalarHeader.Raw) + } + returnHeader(scalarHeader) + return + } + + // handle special case where A and B have both len 1 + if len(dataA.Raw) == int(typ.Size()) && len(dataB.Raw) == int(typ.Size()) { + switch { + case safe && reuse != nil && leftTensor: + storage.Copy(typ, dataReuse, dataA) + err = e.E.MinBetween(typ, dataReuse, dataB) + retVal = reuse + return + case safe && reuse != nil && !leftTensor: + storage.Copy(typ, dataReuse, dataB) + err = e.E.MinBetween(typ, dataReuse, dataA) + retVal = reuse + return + } + } + // standard + switch { + case !safe && reuse == nil: + err = e.E.MinBetween(typ, dataA, dataB) + retVal = a + case safe && reuse != nil && leftTensor: + storage.Copy(typ, dataReuse, dataA) + err = e.E.MinBetween(typ, dataReuse, dataB) + retVal = reuse + case safe && reuse != nil && !leftTensor: + storage.Copy(typ, dataReuse, dataB) + err = e.E.MinBetween(typ, dataA, dataReuse) + retVal = reuse + default: + panic("Unreachable") + } + if newAlloc { + freeScalar(scalarHeader.Raw) + } + returnHeader(scalarHeader) + return +} + +func (e StdEng) MaxBetweenScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { + if err = unaryCheck(t, dtype.Ord); err != nil { + return nil, errors.Wrapf(err, "MaxBetween failed") + } + + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "MaxBetween failed") + } + + var reuse DenseTensor + var safe bool + var ctx context.Context + if ctx, reuse, safe, _, _, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { + return nil, errors.Wrap(err, "Unable to handle funcOpts") + } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } + a := t + typ := t.Dtype().Type + var ait, bit, iit Iterator + var dataA, dataB, dataReuse, scalarHeader *storage.Header + var useIter, newAlloc bool + + if leftTensor { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { + return nil, errors.Wrapf(err, opFail, "StdEng.MaxBetween") + } + scalarHeader = dataB + } else { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { + return nil, errors.Wrapf(err, opFail, "StdEng.MaxBetween") + } + scalarHeader = dataA + } + + // check to see if anything needs to be created + if reuse == nil { + reuse = NewDense(a.Dtype(), a.Shape().Clone(), WithEngine(e)) + dataReuse = reuse.hdr() + if useIter { + iit = IteratorFromDense(reuse) + } + } + + if useIter { + switch { + case !safe && reuse == nil: + err = e.E.MaxBetweenIter(typ, dataA, dataB, ait, bit) + retVal = a + case safe && reuse != nil && !leftTensor: + storage.CopyIter(typ, dataReuse, dataB, iit, bit) + bit.Reset() + iit.Reset() + err = e.E.MaxBetweenIter(typ, dataA, dataReuse, ait, bit) + retVal = reuse + case safe && reuse != nil && leftTensor: + storage.CopyIter(typ, dataReuse, dataA, iit, ait) + ait.Reset() + iit.Reset() + err = e.E.MaxBetweenIter(typ, dataReuse, dataB, iit, bit) + retVal = reuse + default: // safe && bool + panic("Unreachable") + } + if newAlloc { + freeScalar(scalarHeader.Raw) + } + returnHeader(scalarHeader) + return + } + + // handle special case where A and B have both len 1 + if len(dataA.Raw) == int(typ.Size()) && len(dataB.Raw) == int(typ.Size()) { + switch { + case safe && reuse != nil && leftTensor: + storage.Copy(typ, dataReuse, dataA) + err = e.E.MaxBetween(typ, dataReuse, dataB) + retVal = reuse + return + case safe && reuse != nil && !leftTensor: + storage.Copy(typ, dataReuse, dataB) + err = e.E.MaxBetween(typ, dataReuse, dataA) + retVal = reuse + return + } + } + // standard + switch { + case !safe && reuse == nil: + err = e.E.MaxBetween(typ, dataA, dataB) + retVal = a + case safe && reuse != nil && leftTensor: + storage.Copy(typ, dataReuse, dataA) + err = e.E.MaxBetween(typ, dataReuse, dataB) + retVal = reuse + case safe && reuse != nil && !leftTensor: + storage.Copy(typ, dataReuse, dataB) + err = e.E.MaxBetween(typ, dataA, dataReuse) + retVal = reuse + default: + panic("Unreachable") + } + if newAlloc { + freeScalar(scalarHeader.Raw) + } + returnHeader(scalarHeader) + return +} diff --git a/defaultengine_misc.go b/defaultengine_misc.go index b4bf21c..6a0c570 100644 --- a/defaultengine_misc.go +++ b/defaultengine_misc.go @@ -1,81 +1,94 @@ -package tensor - -import ( - "github.com/pkg/errors" - "gorgonia.org/tensor/internal/storage" -) - -func (e StdEng) Clamp(a Tensor, min, max interface{}, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, nonComplexNumberTypes); err != nil { - return nil, errors.Wrap(err, "Clamp failed") - } - - var reuse DenseTensor - var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), false, opts...); err != nil { - return nil, errors.Wrap(err, "Unable to handle funcOpts") - } - - typ := a.Dtype().Type - var ait, rit Iterator - var dataA, dataReuse *storage.Header - var useIter bool - - if dataA, dataReuse, ait, rit, useIter, err = prepDataUnary(a, reuse); err != nil { - return nil, errors.Wrapf(err, opFail, "StdEng.Neg") - } - - if useIter { - switch { - case incr: - cloned := a.Clone().(Tensor) - if err = e.E.ClampIter(typ, cloned.hdr(), ait, min, max); err != 
nil { - return nil, errors.Wrapf(err, "Unable to perform Clamp") - } - ait.Reset() - err = e.E.AddIter(typ, dataReuse, cloned.hdr(), rit, ait) - retVal = reuse - case toReuse: - storage.CopyIter(typ, dataReuse, dataA, rit, ait) - rit.Reset() - err = e.E.ClampIter(typ, dataReuse, rit, min, max) - retVal = reuse - case !safe: - err = e.E.ClampIter(typ, dataA, ait, min, max) - retVal = a - default: - cloned := a.Clone().(Tensor) - err = e.E.ClampIter(typ, cloned.hdr(), ait, min, max) - retVal = cloned - } - return - } - switch { - case incr: - cloned := a.Clone().(Tensor) - if err = e.E.Clamp(typ, cloned.hdr(), min, max); err != nil { - return nil, errors.Wrapf(err, "Unable to perform Clamp") - } - err = e.E.Add(typ, dataReuse, cloned.hdr()) - retVal = reuse - case toReuse: - storage.Copy(typ, dataReuse, dataA) - err = e.E.Clamp(typ, dataReuse, min, max) - retVal = reuse - case !safe: - err = e.E.Clamp(typ, dataA, min, max) - retVal = a - default: - cloned := a.Clone().(Tensor) - err = e.E.Clamp(typ, cloned.hdr(), min, max) - retVal = cloned - } - return -} - -func (e StdEng) FMA(a, x, y Tensor) (Tensor, error) { - return e.Mul(a, x, WithIncr(y)) -} -func (e StdEng) FMAScalar(a Tensor, x interface{}, y Tensor) (Tensor, error) { - return e.MulScalar(a, x, true, WithIncr(y)) -} +package tensor + +import ( + "context" + + "github.com/pkg/errors" + "gorgonia.org/dtype" + "gorgonia.org/tensor/internal/storage" +) + +func (e StdEng) Clamp(a Tensor, min, max interface{}, opts ...FuncOpt) (retVal Tensor, err error) { + if err = unaryCheck(a, dtype.NonComplexNumber); err != nil { + return nil, errors.Wrap(err, "Clamp failed") + } + + var reuse DenseTensor + var safe, toReuse, incr bool + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { + return nil, errors.Wrap(err, "Unable to handle funcOpts") + } + if err = handleCtx(ctx); err != nil { + return nil, err // will be noopError{}, no need to wrap.s + } + + typ := a.Dtype().Type + var ait, rit Iterator + var dataA, dataReuse *storage.Header + var useIter bool + + if dataA, dataReuse, ait, rit, useIter, err = prepDataUnary(a, reuse); err != nil { + return nil, errors.Wrapf(err, opFail, "StdEng.Neg") + } + + if useIter { + switch { + case incr: + cloned := a.Clone().(Tensor) + if err = e.E.ClampIter(typ, cloned.hdr(), ait, min, max); err != nil { + return nil, errors.Wrapf(err, "Unable to perform Clamp") + } + ait.Reset() + err = e.E.AddIter(typ, dataReuse, cloned.hdr(), rit, ait) + retVal = reuse + case toReuse: + storage.CopyIter(typ, dataReuse, dataA, rit, ait) + rit.Reset() + err = e.E.ClampIter(typ, dataReuse, rit, min, max) + retVal = reuse + case !safe: + err = e.E.ClampIter(typ, dataA, ait, min, max) + retVal = a + default: + cloned := a.Clone().(Tensor) + err = e.E.ClampIter(typ, cloned.hdr(), ait, min, max) + retVal = cloned + } + return + } + switch { + case incr: + cloned := a.Clone().(Tensor) + if err = e.E.Clamp(typ, cloned.hdr(), min, max); err != nil { + return nil, errors.Wrapf(err, "Unable to perform Clamp") + } + err = e.E.Add(typ, dataReuse, cloned.hdr()) + retVal = reuse + case toReuse: + storage.Copy(typ, dataReuse, dataA) + err = e.E.Clamp(typ, dataReuse, min, max) + retVal = reuse + case !safe: + err = e.E.Clamp(typ, dataA, min, max) + retVal = a + default: + cloned := a.Clone().(Tensor) + err = e.E.Clamp(typ, cloned.hdr(), min, max) + retVal = cloned + } + return +} + +func (e StdEng) FMA(ctx context.Context, a, x, y Tensor) 
(Tensor, error) { + if err := handleCtx(ctx); err != nil { + return nil, err + } + return e.Mul(a, x, WithIncr(y)) +} +func (e StdEng) FMAScalar(ctx context.Context, a Tensor, x interface{}, y Tensor) (Tensor, error) { + if err := handleCtx(ctx); err != nil { + return nil, err + } + return e.MulScalar(a, x, true, WithIncr(y)) +} diff --git a/defaultengine_prep.go b/defaultengine_prep.go index cb358a7..6f6927a 100644 --- a/defaultengine_prep.go +++ b/defaultengine_prep.go @@ -1,12 +1,17 @@ package tensor import ( + "context" + "reflect" + "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) -func handleFuncOpts(expShape Shape, expType Dtype, strict bool, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr, same bool, err error) { +func handleFuncOpts(expShape Shape, expType dtype.Dtype, o DataOrder, strict bool, opts ...FuncOpt) (ctx context.Context, reuse DenseTensor, safe, toReuse, incr, same bool, err error) { fo := ParseFuncOpts(opts...) + ctx = fo.Context() reuseT, incr := fo.IncrReuse() safe = fo.Safe() @@ -16,7 +21,7 @@ func handleFuncOpts(expShape Shape, expType Dtype, strict bool, opts ...FuncOpt) if toReuse { if reuse, err = getDenseTensor(reuseT); err != nil { returnOpOpt(fo) - err = errors.Wrapf(err, "Cannot reuse a different type of Tensor in a *Dense-Scalar operation") + err = errors.Wrapf(err, "Cannot reuse a Tensor that isn't a DenseTensor. Got %T instead", reuseT) return } @@ -39,13 +44,35 @@ func handleFuncOpts(expShape Shape, expType Dtype, strict bool, opts ...FuncOpt) err = errors.Wrapf(err, "Cannot use reuse: shape mismatch - reuse.len() %v, expShape.TotalSize() %v", reuse.len(), expShape.TotalSize()) return } + if !reuse.Shape().Eq(expShape) { + cloned := expShape.Clone() + if err = reuse.Reshape(cloned...); err != nil { + return + + } + ReturnInts([]int(cloned)) + } + + if !incr && reuse != nil { + reuse.setDataOrder(o) + // err = reuse.reshape(expShape...) + } } returnOpOpt(fo) return } -func binaryCheck(a, b Tensor, tc *typeclass) (err error) { +func handleCtx(ctx context.Context) error { + select { + case <-ctx.Done(): + return noopError{} + default: + } + return nil +} + +func binaryCheck(a, b Tensor, tc dtype.TypeClass) (err error) { // check if the tensors are accessible if !a.IsNativelyAccessible() { return errors.Errorf(inaccessibleData, a) @@ -57,11 +84,11 @@ func binaryCheck(a, b Tensor, tc *typeclass) (err error) { at := a.Dtype() bt := b.Dtype() - if tc != nil { - if err = typeclassCheck(at, tc); err != nil { + if tc != nilTC { + if err = dtype.TypeClassCheck(at, tc); err != nil { return errors.Wrapf(err, typeclassMismatch, "a") } - if err = typeclassCheck(bt, tc); err != nil { + if err = dtype.TypeClassCheck(bt, tc); err != nil { return errors.Wrapf(err, typeclassMismatch, "b") } } @@ -75,19 +102,36 @@ func binaryCheck(a, b Tensor, tc *typeclass) (err error) { return nil } -func unaryCheck(a Tensor, tc *typeclass) error { +func unaryCheck(a Tensor, tc dtype.TypeClass) error { if !a.IsNativelyAccessible() { return errors.Errorf(inaccessibleData, a) } at := a.Dtype() - if tc != nil { - if err := typeclassCheck(at, tc); err != nil { + if tc != nilTC { + if err := dtype.TypeClassCheck(at, tc); err != nil { return errors.Wrapf(err, typeclassMismatch, "a") } } return nil } +// scalarDtypeCheck checks that a scalar value has the same dtype as the dtype of a given tensor. 
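+// A minimal usage sketch (editor's note; the constructors and dtypes are the ones used
+// elsewhere in this package, and the exact error text is an assumption):
+//
+//	d := New(WithShape(2, 2), Of(Float64))
+//	_ = scalarDtypeCheck(d, float64(1)) // nil: scalar dtype matches the tensor's dtype
+//	_ = scalarDtypeCheck(d, float32(1)) // error: float32 scalar vs Float64 tensor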
+func scalarDtypeCheck(a Tensor, b interface{}) error { + var dt dtype.Dtype + switch bt := b.(type) { + case Dtyper: + dt = bt.Dtype() + default: + t := reflect.TypeOf(b) + dt = dtype.Dtype{t} + } + + if a.Dtype() != dt { + return errors.Errorf("Expected scalar to have the same Dtype as the tensor (%v). Got %T instead ", a.Dtype(), b) + } + return nil +} + // prepDataVV prepares the data given the input and reuse tensors. It also retruns several indicators // // useIter indicates that the iterator methods should be used. @@ -101,7 +145,11 @@ func prepDataVV(a, b Tensor, reuse Tensor) (dataA, dataB, dataReuse *storage.Hea } // iter - useIter = a.RequiresIterator() || b.RequiresIterator() || (reuse != nil && reuse.RequiresIterator()) + useIter = a.RequiresIterator() || + b.RequiresIterator() || + (reuse != nil && reuse.RequiresIterator()) || + !a.DataOrder().HasSameOrder(b.DataOrder()) || + (reuse != nil && (!a.DataOrder().HasSameOrder(reuse.DataOrder()) || !b.DataOrder().HasSameOrder(reuse.DataOrder()))) if useIter { ait = a.Iterator() bit = b.Iterator() @@ -122,10 +170,10 @@ func prepDataVV(a, b Tensor, reuse Tensor) (dataA, dataB, dataReuse *storage.Hea return } -func prepDataVS(a Tensor, b interface{}, reuse Tensor) (dataA, dataB, dataReuse *storage.Header, ait, iit Iterator, useIter bool, err error) { +func prepDataVS(a Tensor, b interface{}, reuse Tensor) (dataA, dataB, dataReuse *storage.Header, ait, iit Iterator, useIter bool, newAlloc bool, err error) { // get data dataA = a.hdr() - dataB = scalarToHeader(b) + dataB, newAlloc = scalarToHeader(b) if reuse != nil { dataReuse = reuse.hdr() } @@ -133,19 +181,21 @@ func prepDataVS(a Tensor, b interface{}, reuse Tensor) (dataA, dataB, dataReuse if a.IsScalar() { return } - if a.RequiresIterator() || (reuse != nil && reuse.RequiresIterator()) { + useIter = a.RequiresIterator() || + (reuse != nil && reuse.RequiresIterator()) || + (reuse != nil && !reuse.DataOrder().HasSameOrder(a.DataOrder())) + if useIter { ait = a.Iterator() if reuse != nil { iit = reuse.Iterator() } - useIter = true } return } -func prepDataSV(a interface{}, b Tensor, reuse Tensor) (dataA, dataB, dataReuse *storage.Header, bit, iit Iterator, useIter bool, err error) { +func prepDataSV(a interface{}, b Tensor, reuse Tensor) (dataA, dataB, dataReuse *storage.Header, bit, iit Iterator, useIter bool, newAlloc bool, err error) { // get data - dataA = scalarToHeader(a) + dataA, newAlloc = scalarToHeader(a) dataB = b.hdr() if reuse != nil { dataReuse = reuse.hdr() @@ -155,12 +205,15 @@ func prepDataSV(a interface{}, b Tensor, reuse Tensor) (dataA, dataB, dataReuse if b.IsScalar() { return } - if b.RequiresIterator() || (reuse != nil && reuse.RequiresIterator()) { + useIter = b.RequiresIterator() || + (reuse != nil && reuse.RequiresIterator()) || + (reuse != nil && !reuse.DataOrder().HasSameOrder(b.DataOrder())) + + if useIter { bit = b.Iterator() if reuse != nil { iit = reuse.Iterator() } - useIter = true } return } diff --git a/defaultengine_selbyidx.go b/defaultengine_selbyidx.go new file mode 100644 index 0000000..58e3e42 --- /dev/null +++ b/defaultengine_selbyidx.go @@ -0,0 +1,259 @@ +package tensor + +import ( + "context" + + "github.com/pkg/errors" + "gorgonia.org/tensor/internal/storage" + + "reflect" +) + +// SelectByIndices selects the values given the in `indices` tensor. +// +// Currently SelectByIndices only supports Dense tensors that do not require the use of iterators. 
+// Please make a pull request to support tensors that require the use of an iterator to traverse data. +func (e StdEng) SelectByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + if !indices.Shape().IsVectorLike() { + return nil, errors.Errorf("Expected indices to be a vector. Got %v instead", indices.Shape()) + } + if indices.Dtype() != Int { + return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", indices.Dtype()) + } + // if b is a scalar, then use Slice + if a.Shape().IsScalarEquiv() { + slices := make([]Slice, a.Shape().Dims()) + slices[axis] = ss(getInts(indices)[0]) + return a.Slice(slices...) + } + + expectedShape := a.Shape().Clone() + expectedShape[axis] = indices.Shape().TotalSize() + + var reuse DenseTensor + var safe, toReuse, _ bool + var ctx context.Context + if ctx, reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, a.Dtype(), a.DataOrder(), true, opts...); err != nil { + return nil, errors.Wrap(err, "Unable to handle funcOpts") + } + if err = handleCtx(ctx); err != nil { + return nil, err // will be noopError{}, no need to wrap. + } + if safe || !toReuse && reuse == nil && safe { + // create reuse + reuse = New(WithShape(expectedShape...), Of(a.Dtype())) + } + + if !safe { + if a.Shape()[axis] != indices.Shape().TotalSize() { + expected := a.Shape().Clone() + expected[axis] = indices.Shape().TotalSize() + return nil, errors.Errorf("Expected a safe resuse to have the same shape as the expected shape of the result: %v. The input a has %v ", expected, a.Shape()) + } + + reuse = a.(DenseTensor) + } + + typ := a.Dtype().Type + var dataA, dataB, dataReuse *storage.Header + var ait, bit, iit Iterator + var useIter bool + if dataA, dataB, dataReuse, ait, bit, iit, useIter, _, err = prepDataVV(a, indices, reuse); err != nil { + return nil, errors.Wrapf(err, "StdEng.Add") + } + + if useIter { + e.iterSelectByIdx(axis, dataA, dataB, dataReuse, ait, bit, iit) + //TODO + return + } + + e.selectByIdx(axis, dataB.Ints(), typ, dataA, dataReuse, a.(*Dense).AP, reuse.(*Dense).AP) + return reuse, nil +} + +func (e StdEng) iterSelectByIdx(axis int, dataA, dataB, dataReuse *storage.Header, ait, bit, iit Iterator) { + panic("iterSelectByIdx is not yet implemented") +} + +func (e StdEng) selectByIdx(axis int, indices []int, typ reflect.Type, dataA, dataRetVal *storage.Header, apA, apRet AP) { + isInnermost := axis == apA.shape.Dims()-1 + + outer := ProdInts(apA.shape[:axis]) + + axStride := apA.strides[axis] + retStride := apRet.strides[axis] + var outerRetStride int + if axis == 0 { + // then it's the outermost + outerRetStride = apRet.strides[axis] * 2 + } else { + outerRetStride = apRet.strides[axis-1] + } + + srcCoord := make([]int, apA.shape.Dims()) + dstCoord := make([]int, apRet.shape.Dims()) + + if isInnermost { + prevAxis := axis - 1 + if prevAxis < 0 { + // this may be the case if input is a vector + prevAxis = 0 + } + prevStride := apA.strides[prevAxis] + retPrevStride := apRet.strides[prevAxis] + for i, idx := range indices { + srcCoord[axis] = idx + dstCoord[axis] = i + start, _ := Ltoi(apA.shape, apA.strides, srcCoord...) + dstStart, _ := Ltoi(apRet.shape, apRet.strides, dstCoord...) 
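+ // editor's note: for each of the `outer` leading blocks, the loop below copies the
+ // contiguous run belonging to the selected index into the result, then steps to the next
+ // block using the previous axis' stride on the source side and the result side respectively.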
+ for o := 0; o < outer; o++ { + end := start + axStride + dstEnd := dstStart + retStride + storage.CopySliced(typ, dataRetVal, dstStart, dstEnd, dataA, start, end) + + start += prevStride + dstStart += retPrevStride + + } + } + return + } + + for i, idx := range indices { + srcCoord[axis] = idx + dstCoord[axis] = i + start, _ := Ltoi(apA.shape, apA.strides, srcCoord...) + dstStart, _ := Ltoi(apRet.shape, apRet.strides, dstCoord...) + + for o := 0; o < outer; o++ { + end := start + axStride + dstEnd := dstStart + retStride + + storage.CopySliced(typ, dataRetVal, dstStart, dstEnd, dataA, start, end) + + start = end + axStride + dstStart = dstEnd + (outerRetStride - retStride) + } + } +} + +// SelectByIndicesB computes the gradient of the result of `SelectByIndices`. +// +// Currently SelectByIndicesB only supports Dense tensors that do not require the use of iterators. +// Please make a pull request to support tensors that require the use of an iterator to traverse data. +func (e StdEng) SelectByIndicesB(input, outGrad, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + if !indices.Shape().IsVectorLike() { + return nil, errors.Errorf("Expected indices to be a vector. Got %v instead", outGrad.Shape()) + } + if indices.Dtype() != Int { + return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", outGrad.Dtype()) + } + + // if b is a scalar, then use Slice + if input.Shape().IsScalarEquiv() { + slices := make([]Slice, input.Shape().Dims()) + slices[axis] = ss(outGrad.Data().([]int)[0]) + return input.Slice(slices...) + } + + expectedShape := input.Shape().Clone() + + var reuse DenseTensor + var _, toReuse, _ bool + var ctx context.Context + if ctx, reuse, _, toReuse, _, _, err = handleFuncOpts(input.Shape(), input.Dtype(), input.DataOrder(), true, opts...); err != nil { + return nil, errors.Wrap(err, "Unable to handle funcOpts") + } + if err = handleCtx(ctx); err != nil { + return nil, err // will be noopError{}, no need to wrap. 
+ } + if !toReuse && reuse == nil { + // create reuse + reuse = New(WithShape(expectedShape...), Of(input.Dtype())) + } + + typ := input.Dtype().Type + var _, dataB, dataReuse *storage.Header + var _, bit, iit Iterator + var useIter bool + if _, dataB, dataReuse, _, bit, iit, useIter, _, err = prepDataVV(input, outGrad, reuse); err != nil { + return nil, errors.Wrapf(err, "StdEng.SelectByIndicesB") + } + + if useIter { + e.iterSelectByIndicesB(axis, dataB, dataReuse, bit, iit) + //TODO + return + } + + e.selectByIndicesB(axis, getInts(indices), typ, dataB, dataReuse, outGrad.(*Dense).AP, reuse.(*Dense).AP) + + return reuse, nil +} + +func (e StdEng) iterSelectByIndicesB(axis int, dataB, dataGradA *storage.Header, bit, iit Iterator) { + panic("iterSelectByIndicesB not implemented yet") +} + +func (e StdEng) selectByIndicesB(axis int, indices []int, typ reflect.Type, dataB, dataGradA *storage.Header, apB, apRet AP) { + isInnermost := axis == apB.shape.Dims()-1 + + outer := ProdInts(apB.shape[:axis]) + + axStride := apB.strides[axis] + retStride := apRet.strides[axis] + var outerRetStride int + if axis == 0 { + outerRetStride = apRet.strides[axis] * 2 + } else { + outerRetStride = apRet.strides[axis-1] + } + + dstCoord := make([]int, apB.shape.Dims()) + srcCoord := make([]int, apRet.shape.Dims()) + + if isInnermost { + prevAxis := axis - 1 + if prevAxis < 0 { + // this may be the case if input is a vector + prevAxis = 0 + } + retPrevStride := apB.strides[prevAxis] + prevStride := apRet.strides[prevAxis] + for i, idx := range indices { + dstCoord[axis] = idx + srcCoord[axis] = i + dstStart, _ := Ltoi(apB.shape, apB.strides, dstCoord...) + start, _ := Ltoi(apRet.shape, apRet.strides, srcCoord...) + for o := 0; o < outer; o++ { + dstEnd := dstStart + axStride + end := start + retStride + + e.E.AddSliced(typ, dataGradA, dstStart, dstEnd, dataB, start, end) + + dstStart += prevStride + start += retPrevStride + + } + } + return + } + + for i, idx := range indices { + dstCoord[axis] = idx + srcCoord[axis] = i + dstStart, _ := Ltoi(apRet.shape, apRet.strides, dstCoord...) + start, _ := Ltoi(apB.shape, apB.strides, srcCoord...) + + for o := 0; o < outer; o++ { + dstEnd := dstStart + axStride + end := start + retStride + + e.E.AddSliced(typ, dataGradA, dstStart, dstEnd, dataB, start, end) + + dstStart = dstEnd + axStride + start = end + (outerRetStride - retStride) + } + } +} diff --git a/defaultengine_softmax.go b/defaultengine_softmax.go new file mode 100644 index 0000000..8a7dc3e --- /dev/null +++ b/defaultengine_softmax.go @@ -0,0 +1,703 @@ +package tensor + +import ( + "context" + "fmt" + "math" + "sync" + + "github.com/chewxy/math32" + "github.com/pkg/errors" +) + +// if dims = 2 and axis -1 it returns the last dimension. In this case 1 +func resolveAxis(axis int, dims int) int { + res := axis % dims + if (res < 0 && dims > 0) || (res > 0 && dims < 0) { + return res + dims + } + + return res +} + +// SoftMax performs the softmax operation on the given tensor. Currently it expects the tensor to be a Dense tensor. +// Please make a pull request to support sparse tensors. 
+// +// The softmax function is defined as : +// σ(x) = e^x_i / Σ(e^x_i) +func (e StdEng) SoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + axis = resolveAxis(axis, x.Dims()) + expectedShape := x.Shape() + + var reuse DenseTensor + var safe, toReuse, _ bool + var ctx context.Context + if ctx, reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, x.Dtype(), x.DataOrder(), true, opts...); err != nil { + return nil, errors.Wrap(err, "Unable to handle funcOpts") + } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } + + if safe || !toReuse && reuse == nil && safe { + // create reuse + reuse = New(WithShape(expectedShape...), Of(x.Dtype())) + } + + switch x.Dtype() { + case Float32: + if expectedShape.Dims()-1 == axis { + e.softMaxLastDimF32(reuse, x, axis, false) + } else { + e.softMaxInnerDimF32(reuse, x, axis, false) + } + case Float64: + if expectedShape.Dims()-1 == axis { + e.softMaxLastDimF64(reuse, x, axis, false) + } else { + e.softMaxInnerDimF64(reuse, x, axis, false) + } + default: + return nil, fmt.Errorf("type %v not supported", x.Dtype()) + } + + return reuse, nil +} + +// SoftMaxB computes gradient of the input `x`, given the `output = SoftMax(x)` and its associated gradient. Currently it expects the tensor to be a Dense tensor. +// Please make a pull request to support sparse tensors. +func (e StdEng) SoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + if !output.Shape().Eq(grad.Shape()) { + return nil, fmt.Errorf("output and grad shapes don't match") + } + + if !output.Dtype().Eq(grad.Dtype()) { + return nil, fmt.Errorf("output and grad types don't match") + } + + axis = resolveAxis(axis, output.Dims()) + expectedShape := output.Shape() + + var reuse DenseTensor + var safe, toReuse, _ bool + var ctx context.Context + if ctx, reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, output.Dtype(), output.DataOrder(), true, opts...); err != nil { + return nil, errors.Wrap(err, "Unable to handle funcOpts") + } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } + + if safe || !toReuse && reuse == nil && safe { + // create reuse + reuse = New(WithShape(expectedShape...), Of(output.Dtype())) + } + + switch output.Dtype() { + case Float32: + if expectedShape.Dims()-1 == axis { + e.softMaxBLastDimF32(reuse, output, grad, axis, false) + } else { + e.softMaxBInnerDimF32(reuse, output, grad, axis, false) + } + case Float64: + if expectedShape.Dims()-1 == axis { + e.softMaxBLastDimF64(reuse, output, grad, axis, false) + } else { + e.softMaxBInnerDimF64(reuse, output, grad, axis, false) + } + default: + return nil, fmt.Errorf("type %v not supported", output.Dtype()) + } + + return reuse, nil +} + +// LogSoftMax performs softmax but in log space. This provides some amount of numerical stabilization. +// Conceptually it is the same as performing a logarithm after applying the softmax function. +// Currently it expects the tensor to be a Dense tensor. +// Please make a pull request to support sparse tensors. 
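+// As implemented below (softMaxLastDimF64 and friends with logSoftMax=true), each element of the
+// result is computed as x_i - max(x) - log(Σ_j exp(x_j - max(x))) rather than by taking the log of
+// an already-computed softmax, which is where the extra numerical stability comes from.
+//
+// A hedged usage sketch (editor's note):
+//
+//	x := New(WithShape(2, 3), Of(Float64))
+//	out, err := StdEng{}.LogSoftMax(x, -1) // -1 resolves to the last axis via resolveAxis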
+func (e StdEng) LogSoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + axis = resolveAxis(axis, x.Dims()) + expectedShape := x.Shape() + + var reuse DenseTensor + var safe, toReuse, _ bool + var ctx context.Context + if ctx, reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, x.Dtype(), x.DataOrder(), true, opts...); err != nil { + return nil, errors.Wrap(err, "Unable to handle funcOpts") + } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } + + if safe || !toReuse && reuse == nil && safe { + // create reuse + reuse = New(WithShape(expectedShape...), Of(x.Dtype())) + } + + switch x.Dtype() { + case Float32: + if expectedShape.Dims()-1 == axis { + e.softMaxLastDimF32(reuse, x, axis, true) + } else { + e.softMaxInnerDimF32(reuse, x, axis, true) + } + case Float64: + if expectedShape.Dims()-1 == axis { + e.softMaxLastDimF64(reuse, x, axis, true) + } else { + e.softMaxInnerDimF64(reuse, x, axis, true) + } + default: + return nil, fmt.Errorf("type %v not supported", x.Dtype()) + } + + return reuse, nil +} + +// LogSoftMaxB computes the gradient of the input `x`, given the `output = LogSoftmax(x)` and its associated gradient. +// Currently it expects the tensor to be a Dense tensor. +// Please make a pull request to support sparse tensors. +func (e StdEng) LogSoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + if !output.Shape().Eq(grad.Shape()) { + return nil, fmt.Errorf("output and grad shapes don't match") + } + + if !output.Dtype().Eq(grad.Dtype()) { + return nil, fmt.Errorf("output and grad types don't match") + } + + axis = resolveAxis(axis, output.Dims()) + expectedShape := output.Shape() + + var reuse DenseTensor + var safe, toReuse, _ bool + var ctx context.Context + if ctx, reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, output.Dtype(), output.DataOrder(), true, opts...); err != nil { + return nil, errors.Wrap(err, "Unable to handle funcOpts") + } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } + + if safe || !toReuse && reuse == nil && safe { + // create reuse + reuse = New(WithShape(expectedShape...), Of(output.Dtype())) + } + + switch output.Dtype() { + case Float32: + if expectedShape.Dims()-1 == axis { + e.softMaxBLastDimF32(reuse, output, grad, axis, true) + } else { + e.softMaxBInnerDimF32(reuse, output, grad, axis, true) + } + case Float64: + if expectedShape.Dims()-1 == axis { + e.softMaxBLastDimF64(reuse, output, grad, axis, true) + } else { + e.softMaxBInnerDimF64(reuse, output, grad, axis, true) + } + default: + return nil, fmt.Errorf("type %v not supported", output.Dtype()) + } + + return reuse, nil +} + +func (e StdEng) softMaxLastDimF64(output Tensor, x Tensor, axis int, logSoftMax bool) { + outputArr := getFloat64s(output) + xArr := getFloat64s(x) + + xShape := x.Shape() + + outerSize := 1 + dimSize := xShape[axis] + for i := 0; i < axis; i++ { + outerSize *= xShape[i] + } + + var wg sync.WaitGroup + for ii := 0; ii < outerSize; ii++ { + wg.Add(1) + go func(ii int, wg *sync.WaitGroup) { + maxInput := xArr[0] + for j := 1; j < dimSize; j++ { + i := ii*dimSize + j + + if xArr[i] > maxInput { + maxInput = xArr[i] + } + } + + sumExp := float64(0.0) + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + z := xArr[i] - maxInput + exp := math.Exp(z) + + if logSoftMax { + outputArr[i] = z + } else { + outputArr[i] = exp + } + + sumExp += exp + } + + if !logSoftMax { + sumExp = 1 / sumExp + } + + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + + if logSoftMax { + outputArr[i] -= math.Log(sumExp) + } else { + outputArr[i] *= sumExp + } + } + wg.Done() + }(ii, &wg) + + } + wg.Wait() +} + +func (e StdEng) softMaxBLastDimF64(inputGrad, output, grad Tensor, axis int, logSoftMax bool) { + dx := getFloat64s(inputGrad) + outputArr := getFloat64s(output) + gradArr := getFloat64s(grad) + + outputShape := output.Shape() + + outerSize := 1 + dimSize := outputShape[axis] + for i := 0; i < axis; i++ { + outerSize *= outputShape[i] + } + + var wg sync.WaitGroup + for ii := 0; ii < outerSize; ii++ { + wg.Add(1) + if logSoftMax { + go func(gradArr, dx []float64, ii int, wg *sync.WaitGroup) { + sum := gradArr[ii*dimSize] + for j := 1; j < dimSize; j++ { + i := ii*dimSize + j + + sum += gradArr[i] + } + + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + + dx[i] = gradArr[i] - (math.Exp(outputArr[i]) * sum) + } + wg.Done() + }(gradArr, dx, ii, &wg) + + } else { + go func(outputArr, gradArr, dx []float64, ii int, wg *sync.WaitGroup) { + //mul := make([]float64, dimSize) + var sum float64 + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + + //mul[j] = outputArr[i] * gradArr[i] + sum += outputArr[i] * gradArr[i] + } + + // sum := mul[0] + // for j := 1; j < dimSize; j++ { + // sum += mul[j] + // } + + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + dx[i] = (gradArr[i] - sum) * outputArr[i] + } + wg.Done() + }(outputArr, gradArr, dx, ii, &wg) + } + } + wg.Wait() +} + +func (e StdEng) softMaxInnerDimF64(output Tensor, x Tensor, axis int, logSoftmax bool) { + xShape := x.Shape() + + innerSize, outerSize := 1, 1 + for i := 0; i < axis; i++ { + outerSize *= xShape[i] + } + + for i := axis + 1; i < xShape.Dims(); i++ { + innerSize *= xShape[i] + } + + dimSize := xShape[axis] + dimStride := innerSize + outerStride := dimSize * dimStride + + outputArr := getFloat64s(output) + xArr := getFloat64s(x) + + var wg sync.WaitGroup + for ii := 0; ii < innerSize*outerSize; ii++ { + wg.Add(1) + go func(ii int, wg *sync.WaitGroup) { + outerIndex, innerIndex := divmod(ii, innerSize) 
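+ // editor's note: divmod splits the flattened worker index ii into its quotient and
+ // remainder by innerSize, i.e. the block along the leading axes (outerIndex) and the offset
+ // within the trailing axes (innerIndex); the slices below then start at the first element of
+ // the softmax axis for that position.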
+ + inputPart := xArr[outerIndex*outerStride+innerIndex:] + outputPart := outputArr[outerIndex*outerStride+innerIndex:] + + maxInput := inputPart[0] + for j := 1; j < dimSize; j++ { + i := j * dimStride + + if inputPart[i] > maxInput { + maxInput = inputPart[i] + } + } + + sumExp := 0.0 + for j := 0; j < dimSize; j++ { + i := j * dimStride + + exp := math.Exp(inputPart[i] - maxInput) + + if !logSoftmax { + outputPart[i] = exp + } + + sumExp += exp + } + + if logSoftmax { + sumExp = math.Log(sumExp) + } else { + sumExp = 1 / sumExp + } + + for j := 0; j < dimSize; j++ { + i := j * dimStride + + if logSoftmax { + outputPart[i] = inputPart[i] - maxInput - sumExp + } else { + outputPart[i] *= sumExp + } + } + wg.Done() + }(ii, &wg) + } + wg.Wait() +} + +func (e StdEng) softMaxBInnerDimF64(inputGrad, output, grad Tensor, axis int, logSoftmax bool) { + dxShape := inputGrad.Shape() + + innerSize, outerSize := 1, 1 + for i := 0; i < axis; i++ { + outerSize *= dxShape[i] + } + + for i := axis + 1; i < dxShape.Dims(); i++ { + innerSize *= dxShape[i] + } + + dimSize := dxShape[axis] + dimStride := innerSize + outerStride := dimSize * dimStride + + dxArr := getFloat64s(inputGrad) + outputArr := getFloat64s(output) + gradArr := getFloat64s(grad) + + var wg sync.WaitGroup + for ii := 0; ii < innerSize*outerSize; ii++ { + wg.Add(1) + go func(ii int, wg *sync.WaitGroup) { + outerIndex, innerIndex := divmod(ii, innerSize) + + gradPart := gradArr[outerIndex*outerStride+innerIndex:] + dxPart := dxArr[outerIndex*outerStride+innerIndex:] + outputPart := outputArr[outerIndex*outerStride+innerIndex:] + + sum := 0.0 + for j := 0; j < dimSize; j++ { + i := j * dimStride + + if logSoftmax { + sum += gradPart[i] + } else { + sum += gradPart[i] * outputPart[i] + } + } + + for j := 0; j < dimSize; j++ { + i := j * dimStride + + if logSoftmax { + dxPart[i] = gradPart[i] - math.Exp(outputPart[i])*sum + } else { + dxPart[i] = outputPart[i] * (gradPart[i] - sum) + } + } + wg.Done() + }(ii, &wg) + + } + wg.Wait() +} + +func (e StdEng) softMaxLastDimF32(output Tensor, x Tensor, axis int, logSoftMax bool) { + outputArr := getFloat32s(output) + xArr := getFloat32s(x) + xShape := x.Shape() + + outerSize := 1 + dimSize := xShape[axis] + for i := 0; i < axis; i++ { + outerSize *= xShape[i] + } + + var wg sync.WaitGroup + for ii := 0; ii < outerSize; ii++ { + wg.Add(1) + go func(ii int, wg *sync.WaitGroup) { + maxInput := xArr[0] + for j := 1; j < dimSize; j++ { + i := ii*dimSize + j + + if xArr[i] > maxInput { + maxInput = xArr[i] + } + } + + sumExp := float32(0.0) + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + z := xArr[i] - maxInput + exp := math32.Exp(z) + + if logSoftMax { + outputArr[i] = z + } else { + outputArr[i] = exp + } + + sumExp += exp + } + + if !logSoftMax { + sumExp = 1 / sumExp + } + + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + + if logSoftMax { + outputArr[i] -= math32.Log(sumExp) + } else { + outputArr[i] *= sumExp + } + } + wg.Done() + }(ii, &wg) + } + wg.Wait() +} + +func (e StdEng) softMaxBLastDimF32(inputGrad, output, grad Tensor, axis int, logSoftMax bool) { + dx := getFloat32s(inputGrad) + outputArr := getFloat32s(output) + gradArr := getFloat32s(grad) + + outputShape := output.Shape() + + outerSize := 1 + dimSize := outputShape[axis] + for i := 0; i < axis; i++ { + outerSize *= outputShape[i] + } + + var wg sync.WaitGroup + for ii := 0; ii < outerSize; ii++ { + wg.Add(1) + + if logSoftMax { + go func(ii int, wg *sync.WaitGroup) { + sum := gradArr[ii*dimSize] + for j := 1; j < 
dimSize; j++ { + i := ii*dimSize + j + + sum += gradArr[i] + } + + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + + dx[i] = gradArr[i] - (math32.Exp(outputArr[i]) * sum) + } + wg.Done() + }(ii, &wg) + } else { + go func(ii int, wg *sync.WaitGroup) { + //mul := make([]float32, dimSize) + var sum float32 + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + + //mul[j] = outputArr[i] * gradArr[i] + sum += outputArr[i] * gradArr[i] + } + + // sum := mul[0] + // for j := 1; j < dimSize; j++ { + // sum += mul[j] + // } + + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + + dx[i] = (gradArr[i] - sum) * outputArr[i] + } + wg.Done() + }(ii, &wg) + } + } + wg.Wait() +} + +func (e StdEng) softMaxInnerDimF32(output Tensor, x Tensor, axis int, logSoftmax bool) { + xShape := x.Shape() + + innerSize, outerSize := 1, 1 + for i := 0; i < axis; i++ { + outerSize *= xShape[i] + } + + for i := axis + 1; i < xShape.Dims(); i++ { + innerSize *= xShape[i] + } + + dimSize := xShape[axis] + dimStride := innerSize + outerStride := dimSize * dimStride + + outputArr := getFloat32s(output) + xArr := getFloat32s(x) + + var wg sync.WaitGroup + for ii := 0; ii < innerSize*outerSize; ii++ { + wg.Add(1) + + go func(ii int, wg *sync.WaitGroup) { + outerIndex, innerIndex := divmod(ii, innerSize) + + inputPart := xArr[outerIndex*outerStride+innerIndex:] + outputPart := outputArr[outerIndex*outerStride+innerIndex:] + + maxInput := inputPart[0] + for j := 1; j < dimSize; j++ { + i := j * dimStride + + if inputPart[i] > maxInput { + maxInput = inputPart[i] + } + } + + sumExp := float32(0.0) + for j := 0; j < dimSize; j++ { + i := j * dimStride + + exp := math32.Exp(inputPart[i] - maxInput) + + if !logSoftmax { + outputPart[i] = exp + } + + sumExp += exp + } + + if logSoftmax { + sumExp = math32.Log(sumExp) + } else { + sumExp = 1 / sumExp + } + + for j := 0; j < dimSize; j++ { + i := j * dimStride + + if logSoftmax { + outputPart[i] = inputPart[i] - maxInput - sumExp + } else { + outputPart[i] *= sumExp + } + } + wg.Done() + }(ii, &wg) + } + wg.Wait() +} + +func (e StdEng) softMaxBInnerDimF32(inputGrad, output, grad Tensor, axis int, logSoftmax bool) { + dxShape := inputGrad.Shape() + + innerSize, outerSize := 1, 1 + for i := 0; i < axis; i++ { + outerSize *= dxShape[i] + } + + for i := axis + 1; i < dxShape.Dims(); i++ { + innerSize *= dxShape[i] + } + + dimSize := dxShape[axis] + dimStride := innerSize + outerStride := dimSize * dimStride + + dxArr := getFloat32s(inputGrad) + outputArr := getFloat32s(output) + gradArr := getFloat32s(grad) + + var wg sync.WaitGroup + for ii := 0; ii < innerSize*outerSize; ii++ { + wg.Add(1) + + go func(ii int, wg *sync.WaitGroup) { + outerIndex, innerIndex := divmod(ii, innerSize) + + gradPart := gradArr[outerIndex*outerStride+innerIndex:] + dxPart := dxArr[outerIndex*outerStride+innerIndex:] + outputPart := outputArr[outerIndex*outerStride+innerIndex:] + + sum := float32(0.0) + for j := 0; j < dimSize; j++ { + i := j * dimStride + + if logSoftmax { + sum += gradPart[i] + } else { + sum += gradPart[i] * outputPart[i] + } + } + + for j := 0; j < dimSize; j++ { + i := j * dimStride + + if logSoftmax { + dxPart[i] = gradPart[i] - math32.Exp(outputPart[i])*sum + } else { + dxPart[i] = outputPart[i] * (gradPart[i] - sum) + } + } + wg.Done() + }(ii, &wg) + } + wg.Wait() +} diff --git a/defaultengine_unary.go b/defaultengine_unary.go index 4da968a..8efe589 100644 --- a/defaultengine_unary.go +++ b/defaultengine_unary.go @@ -1,22 +1,29 @@ -// Code generated by genlib2. DO NOT EDIT. 
- package tensor import ( + "context" + "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + func (e StdEng) Neg(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, numberTypes); err != nil { + if err = unaryCheck(a, dtype.Number); err != nil { err = errors.Wrapf(err, "Neg failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -76,15 +83,19 @@ func (e StdEng) Neg(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Inv(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, numberTypes); err != nil { + if err = unaryCheck(a, dtype.Number); err != nil { err = errors.Wrapf(err, "Inv failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -144,15 +155,19 @@ func (e StdEng) Inv(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Square(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, numberTypes); err != nil { + if err = unaryCheck(a, dtype.Number); err != nil { err = errors.Wrapf(err, "Square failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -212,15 +227,19 @@ func (e StdEng) Square(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Cube(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, numberTypes); err != nil { + if err = unaryCheck(a, dtype.Number); err != nil { err = errors.Wrapf(err, "Cube failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. 
+ } typ := a.Dtype().Type var ait, rit Iterator @@ -280,15 +299,19 @@ func (e StdEng) Cube(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Exp(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, floatcmplxTypes); err != nil { + if err = unaryCheck(a, dtype.FloatComplex); err != nil { err = errors.Wrapf(err, "Exp failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -348,15 +371,19 @@ func (e StdEng) Exp(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Tanh(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, floatcmplxTypes); err != nil { + if err = unaryCheck(a, dtype.FloatComplex); err != nil { err = errors.Wrapf(err, "Tanh failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -416,15 +443,19 @@ func (e StdEng) Tanh(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Log(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, floatcmplxTypes); err != nil { + if err = unaryCheck(a, dtype.FloatComplex); err != nil { err = errors.Wrapf(err, "Log failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -484,15 +515,19 @@ func (e StdEng) Log(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Log2(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, floatTypes); err != nil { + if err = unaryCheck(a, dtype.Floats); err != nil { err = errors.Wrapf(err, "Log2 failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. 
+ } typ := a.Dtype().Type var ait, rit Iterator @@ -552,15 +587,19 @@ func (e StdEng) Log2(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Log10(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, floatcmplxTypes); err != nil { + if err = unaryCheck(a, dtype.FloatComplex); err != nil { err = errors.Wrapf(err, "Log10 failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -620,15 +659,19 @@ func (e StdEng) Log10(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Sqrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, floatcmplxTypes); err != nil { + if err = unaryCheck(a, dtype.FloatComplex); err != nil { err = errors.Wrapf(err, "Sqrt failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -688,15 +731,19 @@ func (e StdEng) Sqrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Cbrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, floatTypes); err != nil { + if err = unaryCheck(a, dtype.Floats); err != nil { err = errors.Wrapf(err, "Cbrt failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -756,15 +803,19 @@ func (e StdEng) Cbrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) InvSqrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, floatTypes); err != nil { + if err = unaryCheck(a, dtype.Floats); err != nil { err = errors.Wrapf(err, "InvSqrt failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. 
+ } typ := a.Dtype().Type var ait, rit Iterator @@ -824,15 +875,19 @@ func (e StdEng) InvSqrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Abs(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, signedTypes); err != nil { + if err = unaryCheck(a, dtype.Signed); err != nil { err = errors.Wrapf(err, "Abs failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -892,15 +947,19 @@ func (e StdEng) Abs(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Sign(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, signedTypes); err != nil { + if err = unaryCheck(a, dtype.Signed); err != nil { err = errors.Wrapf(err, "Sign failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator diff --git a/defaultenginefloat32.go b/defaultenginefloat32.go index f260479..2b78aad 100644 --- a/defaultenginefloat32.go +++ b/defaultenginefloat32.go @@ -1,18 +1,20 @@ package tensor import ( - "unsafe" + "context" "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/execution" "gorgonia.org/tensor/internal/storage" "gorgonia.org/vecf32" ) -func handleFuncOptsF32(expShape Shape, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr bool, err error) { +func handleFuncOptsF32(expShape Shape, o DataOrder, opts ...FuncOpt) (ctx context.Context, reuse DenseTensor, safe, toReuse, incr bool, err error) { fo := ParseFuncOpts(opts...) + ctx = fo.Context() reuseT, incr := fo.IncrReuse() safe = fo.Safe() toReuse = reuseT != nil @@ -30,6 +32,12 @@ func handleFuncOptsF32(expShape Shape, opts ...FuncOpt) (reuse DenseTensor, safe err = errors.Wrapf(err, "Cannot use reuse: shape mismatch") return } + + if !incr && reuse != nil { + reuse.setDataOrder(o) + // err = reuse.reshape(expShape...) 
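// Note: handleFuncOptsF32 (like its F64 twin further down) now also returns the
// context from the parsed FuncOpts and, when the reuse tensor is not being used for
// accumulation, stamps the caller's DataOrder onto it. From the caller's side the
// options involved are the ones exercised by the tests in this change (shapes and
// values below are made up for illustration):
//
//	a := New(WithShape(2, 2), WithBacking([]float32{1, 2, 3, 4}))
//	b := New(WithShape(2, 2), WithBacking([]float32{4, 3, 2, 1}))
//	reuse := New(Of(Float32), WithShape(2, 2))
//
//	sum, _ := a.Add(b, WithReuse(reuse)) // sum is the reuse tensor itself
//	_, _ = a.Add(b, UseUnsafe())         // a is overwritten in place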
+ } + } returnOpOpt(fo) return @@ -108,19 +116,22 @@ type Float32Engine struct { } // makeArray allocates a slice for the array -func (e Float32Engine) makeArray(arr *array, t Dtype, size int) { +func (e Float32Engine) makeArray(arr *array, t dtype.Dtype, size int) { if t != Float32 { panic("Float32Engine only creates float32s") } - s := make([]float32, size) + if size < 0 { + panic("Cannot have negative sizes when making array") + } + arr.Header.Raw = make([]byte, size*4) arr.t = t - arr.L = size - arr.C = size - arr.Ptr = unsafe.Pointer(&s[0]) - arr.fix() } -func (e Float32Engine) FMA(a, x, y Tensor) (retVal Tensor, err error) { +func (e Float32Engine) FMA(ctx context.Context, a, x, y Tensor) (retVal Tensor, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } + reuse := y if err = e.checkThree(a, x, reuse); err != nil { return nil, errors.Wrap(err, "Failed checks") @@ -143,7 +154,11 @@ func (e Float32Engine) FMA(a, x, y Tensor) (retVal Tensor, err error) { return } -func (e Float32Engine) FMAScalar(a Tensor, x interface{}, y Tensor) (retVal Tensor, err error) { +func (e Float32Engine) FMAScalar(ctx context.Context, a Tensor, x interface{}, y Tensor) (retVal Tensor, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } + reuse := y if err = e.checkTwo(a, reuse); err != nil { return nil, errors.Wrap(err, "Failed checks") @@ -175,9 +190,14 @@ func (e Float32Engine) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, err = handleFuncOptsF32(a.Shape(), opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, err = handleFuncOptsF32(a.Shape(), a.DataOrder(), opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } + if err = e.checkThree(a, b, reuse); err != nil { return nil, errors.Wrap(err, "Failed checks") } @@ -206,14 +226,21 @@ func (e Float32Engine) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, vecf32.Add(dataA, dataB) retVal = a default: - ret := a.Clone().(headerer) - vecf32.Add(ret.hdr().Float32s(), dataB) + ret, ok := a.Clone().(float32ser) + if !ok { + return nil, errors.Errorf("Unable to get the Float32 data from `a`, of %T", a) + } + vecf32.Add(ret.Float32s(), dataB) retVal = ret.(Tensor) } return } -func (e Float32Engine) Inner(a, b Tensor) (retVal float32, err error) { +func (e Float32Engine) Inner(ctx context.Context, a, b Tensor) (retVal float32, err error) { + if err = handleCtx(ctx); err != nil { + return 0, err // this err will be noopError{}, no need to wrap. + } + var A, B []float32 var AD, BD *Dense var ok bool diff --git a/defaultenginefloat32_test.go b/defaultenginefloat32_test.go new file mode 100644 index 0000000..0ebd016 --- /dev/null +++ b/defaultenginefloat32_test.go @@ -0,0 +1,42 @@ +package tensor + +import ( + "testing" + "testing/quick" +) + +func TestFloat32Engine_makeArray(t *testing.T) { + + // the uint16 is just to make sure that tests are correctly run. + // we don't want the quicktest to randomly generate a size that is so large + // that Go takes a long time just to allocate. We'll test the other sizes (like negative numbers) + // after the quick test. 
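// Note: widening the uint16 to an int also keeps size*4 comfortably inside int
// range. makeArray above sizes the backing store in raw bytes (4 bytes per float32,
// and 8 per float64 in the sibling engine), which is exactly what the
// len(arr.Raw) == size*4 assertion below depends on.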
+ f := func(sz uint16) bool { + size := int(sz) + e := Float32Engine{StdEng{}} + dt := Float32 + arr := array{} + + e.makeArray(&arr, dt, size) + + if len(arr.Raw) != size*4 { + t.Errorf("Expected raw to be size*4. Got %v instead", len(arr.Raw)) + return false + } + v, ok := arr.Data().([]float32) + if !ok { + t.Errorf("Expected v to be []float32. Got %T instead", arr.Data()) + return false + } + + if len(v) != size { + return false + } + return true + } + + if err := quick.Check(f, nil); err != nil { + t.Errorf("Quick test failed %v", err) + } + +} diff --git a/defaultenginefloat64.go b/defaultenginefloat64.go index 6fe2786..85c59b2 100644 --- a/defaultenginefloat64.go +++ b/defaultenginefloat64.go @@ -1,18 +1,19 @@ package tensor import ( - "unsafe" + "context" "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/execution" "gorgonia.org/tensor/internal/storage" "gorgonia.org/vecf64" ) -func handleFuncOptsF64(expShape Shape, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr bool, err error) { +func handleFuncOptsF64(expShape Shape, o DataOrder, opts ...FuncOpt) (ctx context.Context, reuse DenseTensor, safe, toReuse, incr bool, err error) { fo := ParseFuncOpts(opts...) - + ctx = fo.Context() reuseT, incr := fo.IncrReuse() safe = fo.Safe() toReuse = reuseT != nil @@ -30,6 +31,12 @@ func handleFuncOptsF64(expShape Shape, opts ...FuncOpt) (reuse DenseTensor, safe err = errors.Wrapf(err, "Cannot use reuse: shape mismatch") return } + + if !incr && reuse != nil { + reuse.setDataOrder(o) + // err = reuse.reshape(expShape...) + } + } returnOpOpt(fo) return @@ -108,19 +115,19 @@ type Float64Engine struct { } // makeArray allocates a slice for the array -func (e Float64Engine) makeArray(arr *array, t Dtype, size int) { +func (e Float64Engine) makeArray(arr *array, t dtype.Dtype, size int) { if t != Float64 { panic("Float64Engine only creates float64s") } - s := make([]float64, size) + arr.Header.Raw = make([]byte, size*8) arr.t = t - arr.L = size - arr.C = size - arr.Ptr = unsafe.Pointer(&s[0]) - arr.fix() } -func (e Float64Engine) FMA(a, x, y Tensor) (retVal Tensor, err error) { +func (e Float64Engine) FMA(ctx context.Context, a, x, y Tensor) (retVal Tensor, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } + reuse := y if err = e.checkThree(a, x, reuse); err != nil { return nil, errors.Wrap(err, "Failed checks") @@ -143,7 +150,10 @@ func (e Float64Engine) FMA(a, x, y Tensor) (retVal Tensor, err error) { return } -func (e Float64Engine) FMAScalar(a Tensor, x interface{}, y Tensor) (retVal Tensor, err error) { +func (e Float64Engine) FMAScalar(ctx context.Context, a Tensor, x interface{}, y Tensor) (retVal Tensor, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } reuse := y if err = e.checkTwo(a, reuse); err != nil { return nil, errors.Wrap(err, "Failed checks") @@ -175,9 +185,14 @@ func (e Float64Engine) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, err = handleFuncOptsF64(a.Shape(), opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, err = handleFuncOptsF64(a.Shape(), a.DataOrder(), opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
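// Note: the exported engine methods here (FMA, FMAScalar, Inner) now take a
// context.Context as their first argument and consult handleCtx before touching any
// data, in addition to the context carried by FuncOpts elsewhere. A caller-side
// sketch (e, a, x, y are assumed to be a Float64Engine and compatible tensors):
//
//	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
//	defer cancel()
//	out, err := e.FMA(ctx, a, x, y) // y doubles as the output/reuse tensor
//	if err != nil {
//		// a cancelled ctx surfaces here as a noopError before any work is done
//	}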
+ } + if err = e.checkThree(a, b, reuse); err != nil { return nil, errors.Wrap(err, "Failed checks") } @@ -206,14 +221,21 @@ func (e Float64Engine) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, vecf64.Add(dataA, dataB) retVal = a default: - ret := a.Clone().(headerer) - vecf64.Add(ret.hdr().Float64s(), dataB) + ret, ok := a.Clone().(float64ser) + if !ok { + return nil, errors.Errorf("Unable to get the Float64 data from `a`, of %T", a) + } + vecf64.Add(ret.Float64s(), dataB) retVal = ret.(Tensor) } return } -func (e Float64Engine) Inner(a, b Tensor) (retVal float64, err error) { +func (e Float64Engine) Inner(ctx context.Context, a, b Tensor, opts ...FuncOpt) (retVal float64, err error) { + if err = handleCtx(ctx); err != nil { + return 0, err // this err will be noopError{}, no need to wrap. + } + var A, B []float64 var AD, BD *Dense var ok bool diff --git a/defaultenginefloat64_test.go b/defaultenginefloat64_test.go new file mode 100644 index 0000000..2d9391a --- /dev/null +++ b/defaultenginefloat64_test.go @@ -0,0 +1,42 @@ +package tensor + +import ( + "testing" + "testing/quick" +) + +func TestFloat64Engine_makeArray(t *testing.T) { + + // the uint16 is just to make sure that tests are correctly run. + // we don't want the quicktest to randomly generate a size that is so large + // that Go takes a long time just to allocate. We'll test the other sizes (like negative numbers) + // after the quick test. + f := func(sz uint16) bool { + size := int(sz) + e := Float64Engine{StdEng{}} + dt := Float64 + arr := array{} + + e.makeArray(&arr, dt, size) + + if len(arr.Raw) != size*8 { + t.Errorf("Expected raw to be size*8. Got %v instead", len(arr.Raw)) + return false + } + v, ok := arr.Data().([]float64) + if !ok { + t.Errorf("Expected v to be []float32. Got %T instead", arr.Data()) + return false + } + + if len(v) != size { + return false + } + return true + } + + if err := quick.Check(f, nil); err != nil { + t.Errorf("Quick test failed %v", err) + } + +} diff --git a/dense.go b/dense.go index 3152912..1623eee 100644 --- a/dense.go +++ b/dense.go @@ -2,9 +2,12 @@ package tensor import ( "fmt" + "reflect" "unsafe" "github.com/pkg/errors" + "gorgonia.org/dtype" + "gorgonia.org/tensor/internal/storage" ) const ( @@ -13,15 +16,15 @@ const ( // Dense represents a dense tensor - this is the most common form of tensors. It can be used to represent vectors, matrices.. etc type Dense struct { - *AP + AP array flag MemoryFlag e Engine // execution engine for the *Dense - oe standardEngine // optimized engine + oe StandardEngine // optimized engine // backup AP. When a transpose is done, the old *AP is backed up here, for easy untransposes - old *AP + old AP transposeWith []int // if viewOf != nil, then this *Dense is a view. @@ -32,11 +35,11 @@ type Dense struct { } // NewDense creates a new *Dense. It tries its best to get from the tensor pool. -func NewDense(dt Dtype, shape Shape, opts ...ConsOpt) *Dense { +func NewDense(dt dtype.Dtype, shape Shape, opts ...ConsOpt) *Dense { return recycledDense(dt, shape, opts...) } -func recycledDense(dt Dtype, shape Shape, opts ...ConsOpt) (retVal *Dense) { +func recycledDense(dt dtype.Dtype, shape Shape, opts ...ConsOpt) (retVal *Dense) { retVal = recycledDenseNoFix(dt, shape, opts...) 
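// Note: with the array now backed by a raw byte header, Data() (below) rebuilds a
// typed []T on the fly via a reflect.SliceHeader over that memory. The trick in
// isolation, for a []float64 view over 3*8 raw bytes (standalone sketch; the real
// method derives Len, Cap and the element type from the array header):
//
//	raw := make([]byte, 3*8)
//	hdr := reflect.SliceHeader{
//		Data: uintptr(unsafe.Pointer(&raw[0])),
//		Len:  3,
//		Cap:  3,
//	}
//	f64s := *(*[]float64)(unsafe.Pointer(&hdr))
//	f64s[0], f64s[1], f64s[2] = 1, 2, 3 // writes land in raw's bytes
//
// raw must be kept alive for as long as the typed view is in use.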
retVal.fix() if err := retVal.sanity(); err != nil { @@ -45,16 +48,14 @@ func recycledDense(dt Dtype, shape Shape, opts ...ConsOpt) (retVal *Dense) { return } -func recycledDenseNoFix(dt Dtype, shape Shape, opts ...ConsOpt) (retVal *Dense) { - size := shape.TotalSize() - if shape.IsScalar() { - size = 1 - } +func recycledDenseNoFix(dt dtype.Dtype, shape Shape, opts ...ConsOpt) (retVal *Dense) { + // size := shape.TotalSize() + //if shape.IsScalar() { + // size = 1 + //} retVal = borrowDense() retVal.array.t = dt - retVal.array.L = size - retVal.array.C = size - retVal.AP = BorrowAP(shape.Dims()) + retVal.AP.zeroWithDims(shape.Dims()) for _, opt := range opts { opt(retVal) @@ -64,8 +65,7 @@ func recycledDenseNoFix(dt Dtype, shape Shape, opts ...ConsOpt) (retVal *Dense) } func (t *Dense) fromSlice(x interface{}) { - t.array.Ptr = nil - t.array.v = nil + t.array.Header.Raw = nil // GC anything else t.array.fromSlice(x) } @@ -78,44 +78,57 @@ func (t *Dense) addMask(mask []bool) { } func (t *Dense) makeArray(size int) { - if am, ok := t.e.(arrayMaker); ok { - am.makeArray(&t.array, t.t, size) + switch te := t.e.(type) { + case NonStdEngine: + t.flag = MakeMemoryFlag(t.flag, ManuallyManaged) + case arrayMaker: + te.makeArray(&t.array, t.t, size) return + case StandardEngine2: + default: + } - mem, err := t.e.Alloc(calcMemSize(t.t, size)) + memsize := calcMemSize(t.t, size) + mem, err := t.e.Alloc(memsize) if err != nil { panic(err) } - t.array.Ptr = mem.Pointer() - t.array.L = size - t.array.C = size - t.array.fix() + t.array.Raw = storage.FromMemory(mem.Uintptr(), uintptr(memsize)) return - } // Info returns the access pattern which explains how the data in the underlying array is accessed. This is mostly used for debugging. -func (t *Dense) Info() *AP { return t.AP } +func (t *Dense) Info() *AP { return &t.AP } // Dtype returns the data type of the *Dense tensor. -func (t *Dense) Dtype() Dtype { return t.t } +func (t *Dense) Dtype() dtype.Dtype { return t.t } // Data returns the underlying array. If the *Dense represents a scalar value, the scalar value is returned instead func (t *Dense) Data() interface{} { if t.IsScalar() { return t.Get(0) } - return t.v + + // build a type of []T + shdr := reflect.SliceHeader{ + Data: t.array.Uintptr(), + Len: t.array.Len(), + Cap: t.array.Cap(), + } + sliceT := reflect.SliceOf(t.t.Type) + ptr := unsafe.Pointer(&shdr) + val := reflect.Indirect(reflect.NewAt(sliceT, ptr)) + return val.Interface() } // DataSize returns the size of the underlying array. Typically t.DataSize() == t.Shape().TotalSize() func (t *Dense) DataSize() int { if t.IsScalar() { - return 0 + return 0 // DOUBLE CHECK } - return t.L + return t.array.Len() } // Engine returns the execution engine associated with this Tensor @@ -123,11 +136,15 @@ func (t *Dense) Engine() Engine { return t.e } // Reshape reshapes a *Dense. 
If the tensors need to be materialized (either it's a view or transpose), it will be materialized before the reshape happens func (t *Dense) Reshape(dims ...int) error { - if t.viewOf != 0 && t.o.isNotContiguous() { - return errors.Errorf(methodNYI, "Reshape", "non-contiguous views") + if t.Shape().TotalSize() != Shape(dims).TotalSize() { + return errors.Errorf("Cannot reshape %v into %v", t.Shape(), dims) } - if t.old != nil { + if t.viewOf != 0 && t.o.IsNotContiguous() { + return nyierr(methodNYI, "non-contiguous views") + } + + if !t.old.IsZero() { t.Transpose() } @@ -139,6 +156,20 @@ func (t *Dense) reshape(dims ...int) error { return t.sanity() } +func (t *Dense) unsqueeze(axis int) error { + if axis > t.shape.Dims()+1 { + return errors.Errorf("Cannot unsqueeze on axis %d when the tensor has shape %v", axis, t.shape) + } + t.shape = append(t.shape, 1) + copy(t.shape[axis+1:], t.shape[axis:]) + t.shape[axis] = 1 + + t.strides = append(t.strides, 1) + copy(t.strides[axis+1:], t.strides[axis:]) + + return nil +} + // ScalarValue returns the scalar value of a *Tensor, // IF and ONLY IF it's a Tensor representation of a scalar value. // This is required because operations like a (vec · vec) would return a scalar value. @@ -152,16 +183,6 @@ func (t *Dense) ScalarValue() interface{} { return t.Get(0) } -// IsView indicates if the Tensor is a view of another (typically from slicing) -func (t *Dense) IsView() bool { - return t.viewOf != 0 -} - -// IsMaterializeable indicates if the Tensor is materializable - if it has either gone through some transforms or slicing -func (t *Dense) IsMaterializable() bool { - return t.viewOf != 0 || t.old != nil -} - // IsManuallyManaged returns true if the memory associated with this *Dense is manually managed (by the user) func (t *Dense) IsManuallyManaged() bool { return t.flag.manuallyManaged() } @@ -172,19 +193,19 @@ func (t *Dense) IsNativelyAccessible() bool { return t.flag.nativelyAccessible() func (t *Dense) Clone() interface{} { if t.e != nil { retVal := new(Dense) - retVal.AP = t.AP.Clone() + t.AP.CloneTo(&retVal.AP) retVal.t = t.t retVal.e = t.e retVal.oe = t.oe retVal.flag = t.flag - retVal.makeArray(t.L) + retVal.makeArray(t.Len()) - if t.old != nil { + if !t.old.IsZero() { retVal.old = t.old.Clone() + t.old.CloneTo(&retVal.old) } copyDense(retVal, t) retVal.lock() - return retVal } panic("Unreachable: No engine") @@ -234,8 +255,8 @@ func (t *Dense) MaskFromDense(tts ...*Dense) { // Private methods -func (t *Dense) cap() int { return t.array.C } -func (t *Dense) len() int { return t.array.L } // exactly the same as DataSize +func (t *Dense) cap() int { return t.array.Cap() } +func (t *Dense) len() int { return t.array.Len() } // exactly the same as DataSize func (t *Dense) arr() array { return t.array } func (t *Dense) arrPtr() *array { return &t.array } @@ -246,32 +267,49 @@ func (t *Dense) setShape(s ...int) { return } -func (t *Dense) setAP(ap *AP) { t.AP = ap } +func (t *Dense) setAP(ap *AP) { t.AP = *ap } func (t *Dense) fix() { - if t.AP == nil { - return - } - if t.e == nil { t.e = StdEng{} } - if oe, ok := t.e.(standardEngine); ok { + if oe, ok := t.e.(StandardEngine); ok { t.oe = oe } + _, isNonStdEng := t.e.(NonStdEngine) + switch { - case t.IsScalar() && t.array.Ptr == nil: + case isNonStdEng && t.Shape() != nil: + // if there is already data in the array, we should back it up now + raw := t.array.Header.Raw + + // make the array + size := t.Shape().TotalSize() + if t.Shape().IsScalar() { + size = 1 + } + t.makeArray(size) + + if 
len(raw) != 0 { + // copy over if natively accessible + if t.IsNativelyAccessible() { + bs := t.byteSlice() + copy(bs, raw) + } + } + + case t.IsScalar() && t.array.Header.Raw == nil: t.makeArray(1) - case t.Shape() == nil && t.array.Ptr != nil: - size := t.L + case t.Shape() == nil && t.array.Header.Raw != nil: + size := t.Len() if size == 1 { t.SetShape() // scalar } else { t.SetShape(size) // vector } - case t.array.Ptr == nil && t.t != Dtype{}: + case t.array.Header.Raw == nil && t.t != dtype.Dtype{}: size := t.Shape().TotalSize() t.makeArray(size) @@ -298,31 +336,34 @@ func (t *Dense) makeMask() { // sanity is a function that sanity checks that a tensor is correct. func (t *Dense) sanity() error { - if t.AP != nil && t.Shape() == nil && t.array.Ptr == nil { + if !t.AP.IsZero() && t.Shape() == nil && t.array.Header.Raw == nil { return errors.New(emptyTensor) } - size := t.L + size := t.Len() expected := t.Size() if t.viewOf == 0 && size != expected && !t.IsScalar() { - return errors.Errorf(shapeMismatch, t.Shape(), size) + return errors.Wrap(errors.Errorf(shapeMismatch, t.Shape(), size), "sanity check failed") } + // TODO: sanity check for views return nil } -func (t *Dense) isTransposed() bool { return t.old == nil } +// isTransposed returns true if the *Dense holds a transposed array. +func (t *Dense) isTransposed() bool { return t.old.IsZero() } // oshape returns the original shape func (t *Dense) oshape() Shape { - if t.old != nil { + if !t.old.IsZero() { return t.old.Shape() } return t.Shape() } +// ostrides returns the original strides func (t *Dense) ostrides() []int { - if t.old != nil { + if !t.old.IsZero() { return t.old.Strides() } return t.Strides() @@ -333,21 +374,30 @@ func (t *Dense) ShallowClone() *Dense { retVal := borrowDense() retVal.e = t.e retVal.oe = t.oe - retVal.AP = t.AP.Clone() + t.AP.CloneTo(&retVal.AP) retVal.flag = t.flag retVal.array = t.array + + retVal.old = t.old + retVal.transposeWith = t.transposeWith + retVal.viewOf = t.viewOf + retVal.mask = t.mask + retVal.maskIsSoft = t.maskIsSoft return retVal } -func (t *Dense) oldAP() *AP { return t.old } -func (t *Dense) setOldAP(ap *AP) { t.old = ap } +func (t *Dense) oldAP() *AP { return &t.old } +func (t *Dense) setOldAP(ap *AP) { t.old = *ap } func (t *Dense) transposeAxes() []int { return t.transposeWith } + +//go:nocheckptr func (t *Dense) parentTensor() *Dense { if t.viewOf != 0 { return (*Dense)(unsafe.Pointer(t.viewOf)) } return nil } + func (t *Dense) setParentTensor(d *Dense) { if d == nil { t.viewOf = 0 @@ -536,8 +586,8 @@ func (t *Dense) Memset(x interface{}) error { if !t.IsNativelyAccessible() { return errors.Errorf(inaccessibleData, t) } - if t.IsMaterializable() { - it := NewFlatIterator(t.AP) + if t.RequiresIterator() { + it := newFlatIterator(&t.AP) return t.array.memsetIter(x, it) } return t.array.Memset(x) @@ -555,12 +605,20 @@ func (t *Dense) Eq(other interface{}) bool { return t.array.Eq(&ot.array) } + if ot, ok := other.(DenseTensor); ok { + if !t.Shape().Eq(ot.Shape()) { + return false + } + + return t.array.Eq(ot.arrPtr()) + } + return false } func (t *Dense) Zero() { - if t.IsMaterializable() { - it := NewFlatIterator(t.AP) + if t.RequiresIterator() { + it := newFlatIterator(&t.AP) if err := t.zeroIter(it); err != nil { panic(err) } @@ -584,12 +642,13 @@ func (t *Dense) slice(start, end int) { t.array = t.array.slice(start, end) } +// RequiresIterator indicates if an iterator is required to read the data in *Dense in the correct fashion func (t *Dense) RequiresIterator() bool { if 
t.len() == 1 { return false } // non continuous slice, transpose, or masked. If it's a slice and contiguous, then iterator is not required - if !t.o.isContiguous() || t.old != nil || t.IsMasked() { + if !t.o.IsContiguous() || !t.old.IsZero() || t.IsMasked() { return true } return false @@ -597,4 +656,4 @@ func (t *Dense) RequiresIterator() bool { func (t *Dense) Iterator() Iterator { return IteratorFromDense(t) } -func (t *Dense) standardEngine() standardEngine { return t.oe } +func (t *Dense) standardEngine() StandardEngine { return t.oe } diff --git a/dense_apply_test.go b/dense_apply_test.go index 5e8c23d..793f2c5 100644 --- a/dense_apply_test.go +++ b/dense_apply_test.go @@ -1,222 +1,224 @@ -package tensor - -import ( - "math/rand" - "testing" - "testing/quick" - "time" - "unsafe" -) - -func getMutateVal(dt Dtype) interface{} { - switch dt { - case Int: - return int(1) - case Int8: - return int8(1) - case Int16: - return int16(1) - case Int32: - return int32(1) - case Int64: - return int64(1) - case Uint: - return uint(1) - case Uint8: - return uint8(1) - case Uint16: - return uint16(1) - case Uint32: - return uint32(1) - case Uint64: - return uint64(1) - case Float32: - return float32(1) - case Float64: - return float64(1) - case Complex64: - var c complex64 = 1 - return c - case Complex128: - var c complex128 = 1 - return c - case Bool: - return true - case String: - return "Hello World" - case Uintptr: - return uintptr(0xdeadbeef) - case UnsafePointer: - return unsafe.Pointer(uintptr(0xdeadbeef)) - } - return nil -} - -func getMutateFn(dt Dtype) interface{} { - switch dt { - case Int: - return mutateI - case Int8: - return mutateI8 - case Int16: - return mutateI16 - case Int32: - return mutateI32 - case Int64: - return mutateI64 - case Uint: - return mutateU - case Uint8: - return mutateU8 - case Uint16: - return mutateU16 - case Uint32: - return mutateU32 - case Uint64: - return mutateU64 - case Float32: - return mutateF32 - case Float64: - return mutateF64 - case Complex64: - return mutateC64 - case Complex128: - return mutateC128 - case Bool: - return mutateB - case String: - return mutateStr - case Uintptr: - return mutateUintptr - case UnsafePointer: - return mutateUnsafePointer - } - return nil -} - -func TestDense_Apply(t *testing.T) { - var r *rand.Rand - mut := func(q *Dense) bool { - var mutVal interface{} - if mutVal = getMutateVal(q.Dtype()); mutVal == nil { - return true // we'll temporarily skip those we cannot mutate/get a mutation value - } - var fn interface{} - if fn = getMutateFn(q.Dtype()); fn == nil { - return true // we'll skip those that we cannot mutate - } - - we, eqFail := willerr(q, nil, nil) - _, ok := q.Engine().(Mapper) - we = we || !ok - - a := q.Clone().(*Dense) - correct := q.Clone().(*Dense) - correct.Memset(mutVal) - ret, err := a.Apply(fn) - if err, retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly { - if err != nil { - return false - } - return true - } - if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) { - return false - } - - // wrong fn type/illogical values - if _, err = a.Apply(getMutateFn); err == nil { - t.Error("Expected an error") - return false - } - return true - } - r = rand.New(rand.NewSource(time.Now().UnixNano())) - if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil { - t.Errorf("Applying mutation function failed %v", err) - } -} - -func TestDense_Apply_unsafe(t *testing.T) { - var r *rand.Rand - mut := func(q *Dense) bool { - var mutVal interface{} - if mutVal = getMutateVal(q.Dtype()); 
mutVal == nil { - return true // we'll temporarily skip those we cannot mutate/get a mutation value - } - var fn interface{} - if fn = getMutateFn(q.Dtype()); fn == nil { - return true // we'll skip those that we cannot mutate - } - - we, eqFail := willerr(q, nil, nil) - _, ok := q.Engine().(Mapper) - we = we || !ok - - a := q.Clone().(*Dense) - correct := q.Clone().(*Dense) - correct.Memset(mutVal) - ret, err := a.Apply(fn, UseUnsafe()) - if err, retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly { - if err != nil { - return false - } - return true - } - if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) { - return false - } - if ret != a { - t.Error("Expected ret == correct (Unsafe option was used)") - return false - } - return true - } - r = rand.New(rand.NewSource(time.Now().UnixNano())) - if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil { - t.Errorf("Applying mutation function failed %v", err) - } -} - -func TestDense_Apply_reuse(t *testing.T) { - var r *rand.Rand - mut := func(q *Dense) bool { - var mutVal interface{} - if mutVal = getMutateVal(q.Dtype()); mutVal == nil { - return true // we'll temporarily skip those we cannot mutate/get a mutation value - } - var fn interface{} - if fn = getMutateFn(q.Dtype()); fn == nil { - return true // we'll skip those that we cannot mutate - } - - we, eqFail := willerr(q, nil, nil) - _, ok := q.Engine().(Mapper) - we = we || !ok - - a := q.Clone().(*Dense) - reuse := q.Clone().(*Dense) - reuse.Zero() - correct := q.Clone().(*Dense) - correct.Memset(mutVal) - ret, err := a.Apply(fn, WithReuse(reuse)) - if err, retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly { - if err != nil { - return false - } - return true - } - if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) { - return false - } - if ret != reuse { - t.Error("Expected ret == correct (Unsafe option was used)") - return false - } - return true - } - r = rand.New(rand.NewSource(time.Now().UnixNano())) - if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil { - t.Errorf("Applying mutation function failed %v", err) - } -} +package tensor + +import ( + "math/rand" + "testing" + "testing/quick" + "time" + "unsafe" + + "gorgonia.org/dtype" +) + +func getMutateVal(dt dtype.Dtype) interface{} { + switch dt { + case Int: + return int(1) + case Int8: + return int8(1) + case Int16: + return int16(1) + case Int32: + return int32(1) + case Int64: + return int64(1) + case Uint: + return uint(1) + case Uint8: + return uint8(1) + case Uint16: + return uint16(1) + case Uint32: + return uint32(1) + case Uint64: + return uint64(1) + case Float32: + return float32(1) + case Float64: + return float64(1) + case Complex64: + var c complex64 = 1 + return c + case Complex128: + var c complex128 = 1 + return c + case Bool: + return true + case String: + return "Hello World" + case Uintptr: + return uintptr(0xdeadbeef) + case UnsafePointer: + return unsafe.Pointer(uintptr(0xdeadbeef)) + } + return nil +} + +func getMutateFn(dt dtype.Dtype) interface{} { + switch dt { + case Int: + return mutateI + case Int8: + return mutateI8 + case Int16: + return mutateI16 + case Int32: + return mutateI32 + case Int64: + return mutateI64 + case Uint: + return mutateU + case Uint8: + return mutateU8 + case Uint16: + return mutateU16 + case Uint32: + return mutateU32 + case Uint64: + return mutateU64 + case Float32: + return mutateF32 + case Float64: + return mutateF64 + case Complex64: + return mutateC64 + case Complex128: + return mutateC128 + case Bool: + 
return mutateB + case String: + return mutateStr + case Uintptr: + return mutateUintptr + case UnsafePointer: + return mutateUnsafePointer + } + return nil +} + +func TestDense_Apply(t *testing.T) { + var r *rand.Rand + mut := func(q *Dense) bool { + var mutVal interface{} + if mutVal = getMutateVal(q.Dtype()); mutVal == nil { + return true // we'll temporarily skip those we cannot mutate/get a mutation value + } + var fn interface{} + if fn = getMutateFn(q.Dtype()); fn == nil { + return true // we'll skip those that we cannot mutate + } + + we, eqFail := willerr(q, nilTC, nilTC) + _, ok := q.Engine().(Mapper) + we = we || !ok + + a := q.Clone().(*Dense) + correct := q.Clone().(*Dense) + correct.Memset(mutVal) + ret, err := a.Apply(fn) + if err, retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly { + if err != nil { + return false + } + return true + } + if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) { + return false + } + + // wrong fn type/illogical values + if _, err = a.Apply(getMutateFn); err == nil { + t.Error("Expected an error") + return false + } + return true + } + r = rand.New(rand.NewSource(time.Now().UnixNano())) + if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil { + t.Errorf("Applying mutation function failed %v", err) + } +} + +func TestDense_Apply_unsafe(t *testing.T) { + var r *rand.Rand + mut := func(q *Dense) bool { + var mutVal interface{} + if mutVal = getMutateVal(q.Dtype()); mutVal == nil { + return true // we'll temporarily skip those we cannot mutate/get a mutation value + } + var fn interface{} + if fn = getMutateFn(q.Dtype()); fn == nil { + return true // we'll skip those that we cannot mutate + } + + we, eqFail := willerr(q, nilTC, nilTC) + _, ok := q.Engine().(Mapper) + we = we || !ok + + a := q.Clone().(*Dense) + correct := q.Clone().(*Dense) + correct.Memset(mutVal) + ret, err := a.Apply(fn, UseUnsafe()) + if err, retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly { + if err != nil { + return false + } + return true + } + if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) { + return false + } + if ret != a { + t.Error("Expected ret == correct (Unsafe option was used)") + return false + } + return true + } + r = rand.New(rand.NewSource(time.Now().UnixNano())) + if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil { + t.Errorf("Applying mutation function failed %v", err) + } +} + +func TestDense_Apply_reuse(t *testing.T) { + var r *rand.Rand + mut := func(q *Dense) bool { + var mutVal interface{} + if mutVal = getMutateVal(q.Dtype()); mutVal == nil { + return true // we'll temporarily skip those we cannot mutate/get a mutation value + } + var fn interface{} + if fn = getMutateFn(q.Dtype()); fn == nil { + return true // we'll skip those that we cannot mutate + } + + we, eqFail := willerr(q, nilTC, nilTC) + _, ok := q.Engine().(Mapper) + we = we || !ok + + a := q.Clone().(*Dense) + reuse := q.Clone().(*Dense) + reuse.Zero() + correct := q.Clone().(*Dense) + correct.Memset(mutVal) + ret, err := a.Apply(fn, WithReuse(reuse)) + if err, retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly { + if err != nil { + return false + } + return true + } + if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) { + return false + } + if ret != reuse { + t.Error("Expected ret == correct (Unsafe option was used)") + return false + } + return true + } + r = rand.New(rand.NewSource(time.Now().UnixNano())) + if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil { + 
t.Errorf("Applying mutation function failed %v", err) + } +} diff --git a/dense_argmethods.go b/dense_argmethods.go index bfdc0d7..fdace5f 100644 --- a/dense_argmethods.go +++ b/dense_argmethods.go @@ -7,13 +7,14 @@ import "github.com/pkg/errors" // Argmax finds the index of the max value along the axis provided func (t *Dense) Argmax(axis int) (retVal *Dense, err error) { e := t.e + ctx := ctxFromEngine(e) switch am := e.(type) { case denseArgmaxer: - return am.argmaxDenseTensor(t, axis) + return am.argmaxDenseTensor(ctx, t, axis) case Argmaxer: var ret Tensor var ok bool - if ret, err = am.Argmax(t, axis); err != nil { + if ret, err = am.Argmax(ctx, t, axis); err != nil { return nil, errors.Wrapf(err, opFail, "Argmax") } if retVal, ok = ret.(*Dense); !ok { @@ -29,13 +30,14 @@ func (t *Dense) Argmax(axis int) (retVal *Dense, err error) { // Argmin finds the index of the min value along the axis provided func (t *Dense) Argmin(axis int) (retVal *Dense, err error) { e := t.e + ctx := ctxFromEngine(e) switch am := e.(type) { case denseArgminer: - return am.argminDenseTensor(t, axis) + return am.argminDenseTensor(ctx, t, axis) case Argminer: var ret Tensor var ok bool - if ret, err = am.Argmin(t, axis); err != nil { + if ret, err = am.Argmin(ctx, t, axis); err != nil { return nil, errors.Wrapf(err, opFail, "Argmax") } if retVal, ok = ret.(*Dense); !ok { diff --git a/dense_argmethods_test.go b/dense_argmethods_test.go index a4b03bd..a90b957 100644 --- a/dense_argmethods_test.go +++ b/dense_argmethods_test.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -10,6 +8,8 @@ import ( "github.com/stretchr/testify/assert" ) +// Code generated by genlib2. DO NOT EDIT. + /* Test data */ var basicDenseI = New(WithShape(2, 3, 4, 5, 2), WithBacking([]int{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3})) diff --git a/dense_arith.go b/dense_arith.go index 7218d37..5c4eba9 100644 --- a/dense_arith.go +++ b/dense_arith.go @@ -1,9 +1,9 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import "github.com/pkg/errors" +// Code generated by genlib2. DO NOT EDIT. + // Add performs t + other elementwise. Both t and other must have the same shape. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (t *Dense) Add(other *Dense, opts ...FuncOpt) (retVal *Dense, err error) { diff --git a/dense_arith_test.go b/dense_arith_test.go index 039db67..423fc85 100644 --- a/dense_arith_test.go +++ b/dense_arith_test.go @@ -1,17 +1,19 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( "testing" "testing/quick" + + "gorgonia.org/dtype" ) +// Code generated by genlib2. DO NOT EDIT. 
+ func TestDense_Add(t *testing.T) { iden := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -37,7 +39,7 @@ func TestDense_Sub(t *testing.T) { inv := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Suber) we = we || !ok @@ -64,7 +66,7 @@ func TestDense_Mul(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -91,7 +93,7 @@ func TestDense_Div(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -118,7 +120,7 @@ func TestDense_Pow(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := a.Engine().(Power) we = we || !ok @@ -144,7 +146,7 @@ func TestDense_Add_unsafe(t *testing.T) { iden := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -163,7 +165,6 @@ func TestDense_Add_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -175,7 +176,7 @@ func TestDense_Sub_unsafe(t *testing.T) { inv := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Suber) we = we || !ok @@ -195,7 +196,6 @@ func TestDense_Sub_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -207,7 +207,7 @@ func TestDense_Mul_unsafe(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -226,7 +226,6 @@ func TestDense_Mul_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -239,7 +238,7 @@ func TestDense_Div_unsafe(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) 
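// Note: b is about to be filled with the multiplicative identity (all ones), so the
// UseUnsafe() division exercised below must hand back a itself with its values unchanged.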
b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -259,7 +258,6 @@ func TestDense_Div_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -271,7 +269,7 @@ func TestDense_Pow_unsafe(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := a.Engine().(Power) we = we || !ok @@ -290,7 +288,6 @@ func TestDense_Pow_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -303,7 +300,7 @@ func TestDense_Add_reuse(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -329,13 +326,65 @@ func TestDense_Add_reuse(t *testing.T) { t.Errorf("Identity test for Add failed: %v", err) } + mut := func(a, b *Dense, reuseA bool) bool { + // req because we're only testing on one kind of tensor/engine combo + a.e = StdEng{} + a.oe = StdEng{} + a.flag = 0 + b.e = StdEng{} + b.oe = StdEng{} + b.flag = 0 + + if a.Dtype() != b.Dtype() { + return true + } + if !a.Shape().Eq(b.Shape()) { + return true + } + + correct, err := a.Add(b) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := a.Engine().(Adder) + we = we || !ok + + var ret, reuse *Dense + if reuseA { + ret, err = a.Add(b, WithReuse(a)) + reuse = a + } else { + ret, err = a.Add(b, WithReuse(b)) + reuse = b + } + + if err, retEarly := qcErrCheck(t, "Add", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + + if reuse != ret { + t.Errorf("Expected reuse to be the same as retVal") + return false + } + + return true + } + if err := quick.Check(mut, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Reuse Mutation test for Add failed: %v", err) + } + } func TestDense_Sub_reuse(t *testing.T) { inv := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Suber) we = we || !ok @@ -361,6 +410,58 @@ func TestDense_Sub_reuse(t *testing.T) { if err := quick.Check(inv, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { t.Errorf("Inv test for Sub failed: %v", err) } + mut := func(a, b *Dense, reuseA bool) bool { + // req because we're only testing on one kind of tensor/engine combo + a.e = StdEng{} + a.oe = StdEng{} + a.flag = 0 + b.e = StdEng{} + b.oe = StdEng{} + b.flag = 0 + + if a.Dtype() != b.Dtype() { + return true + } + if !a.Shape().Eq(b.Shape()) { + return true + } + + correct, err := a.Sub(b) + we, willFailEq 
:= willerr(a, dtype.Number, nilTC) + _, ok := a.Engine().(Suber) + we = we || !ok + + var ret, reuse *Dense + if reuseA { + ret, err = a.Sub(b, WithReuse(a)) + reuse = a + } else { + ret, err = a.Sub(b, WithReuse(b)) + reuse = b + } + + if err, retEarly := qcErrCheck(t, "Sub", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + + if reuse != ret { + t.Errorf("Expected reuse to be the same as retVal") + return false + } + + return true + } + if err := quick.Check(mut, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Reuse Mutation test for Sub failed: %v", err) + } + } func TestDense_Mul_reuse(t *testing.T) { iden := func(a *Dense) bool { @@ -368,7 +469,7 @@ func TestDense_Mul_reuse(t *testing.T) { b.Memset(identityVal(1, a.t)) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -394,6 +495,58 @@ func TestDense_Mul_reuse(t *testing.T) { t.Errorf("Identity test for Mul failed: %v", err) } + mut := func(a, b *Dense, reuseA bool) bool { + // req because we're only testing on one kind of tensor/engine combo + a.e = StdEng{} + a.oe = StdEng{} + a.flag = 0 + b.e = StdEng{} + b.oe = StdEng{} + b.flag = 0 + + if a.Dtype() != b.Dtype() { + return true + } + if !a.Shape().Eq(b.Shape()) { + return true + } + + correct, err := a.Mul(b) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := a.Engine().(Muler) + we = we || !ok + + var ret, reuse *Dense + if reuseA { + ret, err = a.Mul(b, WithReuse(a)) + reuse = a + } else { + ret, err = a.Mul(b, WithReuse(b)) + reuse = b + } + + if err, retEarly := qcErrCheck(t, "Mul", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + + if reuse != ret { + t.Errorf("Expected reuse to be the same as retVal") + return false + } + + return true + } + if err := quick.Check(mut, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Reuse Mutation test for Mul failed: %v", err) + } + } func TestDense_Div_reuse(t *testing.T) { inv := func(a *Dense) bool { @@ -401,7 +554,7 @@ func TestDense_Div_reuse(t *testing.T) { b.Memset(identityVal(1, a.t)) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -427,6 +580,58 @@ func TestDense_Div_reuse(t *testing.T) { if err := quick.Check(inv, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { t.Errorf("Inv test for Div failed: %v", err) } + mut := func(a, b *Dense, reuseA bool) bool { + // req because we're only testing on one kind of tensor/engine combo + a.e = StdEng{} + a.oe = StdEng{} + a.flag = 0 + b.e = StdEng{} + b.oe = StdEng{} + b.flag = 0 + + if a.Dtype() != b.Dtype() { + return true + } + if !a.Shape().Eq(b.Shape()) { + return true + } + + correct, err := a.Div(b) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := a.Engine().(Diver) + we = we || !ok + + var ret, reuse *Dense + if reuseA { + ret, err = a.Div(b, WithReuse(a)) + reuse = a + } else { + ret, err = a.Div(b, WithReuse(b)) + reuse = b + } + + if err, retEarly := qcErrCheck(t, 
"Div", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + + if reuse != ret { + t.Errorf("Expected reuse to be the same as retVal") + return false + } + + return true + } + if err := quick.Check(mut, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Reuse Mutation test for Div failed: %v", err) + } + } func TestDense_Pow_reuse(t *testing.T) { iden := func(a *Dense) bool { @@ -434,7 +639,7 @@ func TestDense_Pow_reuse(t *testing.T) { b.Memset(identityVal(1, a.t)) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := a.Engine().(Power) we = we || !ok @@ -468,7 +673,7 @@ func TestDense_Add_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -497,7 +702,7 @@ func TestDense_Sub_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Suber) we = we || !ok @@ -527,7 +732,7 @@ func TestDense_Mul_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -557,7 +762,7 @@ func TestDense_Div_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -587,7 +792,7 @@ func TestDense_Pow_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := a.Engine().(Power) we = we || !ok @@ -615,7 +820,7 @@ func TestDense_AddScalar(t *testing.T) { b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -641,7 +846,7 @@ func TestDense_AddScalar(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -662,6 +867,32 @@ func TestDense_AddScalar(t *testing.T) { t.Errorf("Identity test for Add (scalar as left, tensor as right) failed: %v", err) } + type Foo int + wt1 := func(a *Dense) bool { + b := Foo(0) + ret, err := Add(a, b) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Add (tensor as left, scalar as right) failed: %v", err) + } + + wt2 := func(a *Dense) bool { + b := Foo(0) + ret, err := Add(b, a) + if err == nil { + return false + } + _ = 
ret + return true + } + if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Add (tensor as right, scalar as left) failed: %v", err) + } } func TestDense_SubScalar(t *testing.T) { inv1 := func(q *Dense) bool { @@ -669,7 +900,7 @@ func TestDense_SubScalar(t *testing.T) { b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -695,7 +926,7 @@ func TestDense_SubScalar(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -716,6 +947,32 @@ func TestDense_SubScalar(t *testing.T) { if err := quick.Check(inv2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { t.Errorf("Inv test for Sub (scalar as left, tensor as right) failed: %v", err) } + type Foo int + wt1 := func(a *Dense) bool { + b := Foo(0) + ret, err := Sub(a, b) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Sub (tensor as left, scalar as right) failed: %v", err) + } + + wt2 := func(a *Dense) bool { + b := Foo(0) + ret, err := Sub(b, a) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Sub (tensor as right, scalar as left) failed: %v", err) + } } func TestDense_MulScalar(t *testing.T) { iden1 := func(q *Dense) bool { @@ -723,7 +980,7 @@ func TestDense_MulScalar(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -749,7 +1006,7 @@ func TestDense_MulScalar(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -770,6 +1027,32 @@ func TestDense_MulScalar(t *testing.T) { t.Errorf("Identity test for Mul (scalar as left, tensor as right) failed: %v", err) } + type Foo int + wt1 := func(a *Dense) bool { + b := Foo(0) + ret, err := Mul(a, b) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Mul (tensor as left, scalar as right) failed: %v", err) + } + + wt2 := func(a *Dense) bool { + b := Foo(0) + ret, err := Mul(b, a) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Mul (tensor as right, scalar as left) failed: %v", err) + } } func TestDense_DivScalar(t *testing.T) { inv1 := func(q *Dense) bool { @@ -777,7 +1060,7 @@ func TestDense_DivScalar(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Diver) we = we || !ok @@ -799,6 +1082,32 
@@ func TestDense_DivScalar(t *testing.T) { t.Errorf("Inv test for Div (tensor as left, scalar as right) failed: %v", err) } + type Foo int + wt1 := func(a *Dense) bool { + b := Foo(0) + ret, err := Div(a, b) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Div (tensor as left, scalar as right) failed: %v", err) + } + + wt2 := func(a *Dense) bool { + b := Foo(0) + ret, err := Div(b, a) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Div (tensor as right, scalar as left) failed: %v", err) + } } func TestDense_PowScalar(t *testing.T) { iden1 := func(q *Dense) bool { @@ -806,7 +1115,7 @@ func TestDense_PowScalar(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := q.Engine().(Power) we = we || !ok @@ -828,6 +1137,32 @@ func TestDense_PowScalar(t *testing.T) { t.Errorf("Identity test for Pow (tensor as left, scalar as right) failed: %v", err) } + type Foo int + wt1 := func(a *Dense) bool { + b := Foo(0) + ret, err := Pow(a, b) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Pow (tensor as left, scalar as right) failed: %v", err) + } + + wt2 := func(a *Dense) bool { + b := Foo(0) + ret, err := Pow(b, a) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Pow (tensor as right, scalar as left) failed: %v", err) + } } func TestDense_AddScalar_unsafe(t *testing.T) { iden1 := func(q *Dense) bool { @@ -835,7 +1170,7 @@ func TestDense_AddScalar_unsafe(t *testing.T) { b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -854,7 +1189,6 @@ func TestDense_AddScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -866,7 +1200,7 @@ func TestDense_AddScalar_unsafe(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -885,7 +1219,6 @@ func TestDense_AddScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -899,7 +1232,7 @@ func TestDense_SubScalar_unsafe(t *testing.T) { b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -919,7 +1252,6 @@ func TestDense_SubScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -930,7 +1262,7 @@ func 
TestDense_SubScalar_unsafe(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -950,7 +1282,6 @@ func TestDense_SubScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -963,7 +1294,7 @@ func TestDense_MulScalar_unsafe(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -982,7 +1313,6 @@ func TestDense_MulScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -994,7 +1324,7 @@ func TestDense_MulScalar_unsafe(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1013,7 +1343,6 @@ func TestDense_MulScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -1027,7 +1356,7 @@ func TestDense_DivScalar_unsafe(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Diver) we = we || !ok @@ -1047,7 +1376,6 @@ func TestDense_DivScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -1061,7 +1389,7 @@ func TestDense_PowScalar_unsafe(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := q.Engine().(Power) we = we || !ok @@ -1080,7 +1408,6 @@ func TestDense_PowScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -1096,7 +1423,7 @@ func TestDense_AddScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1128,7 +1455,7 @@ func TestDense_AddScalar_reuse(t *testing.T) { b := identityVal(0, q.t) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1154,6 +1481,58 @@ func TestDense_AddScalar_reuse(t *testing.T) { t.Errorf("Identity test for Add (scalar as left, tensor as right) failed: %v", err) } + mut := func(a, b *Dense, reuseA bool) bool { + // req because we're only testing on one kind of tensor/engine combo + a.e = StdEng{} + a.oe = StdEng{} + a.flag = 0 + b.e = StdEng{} + b.oe = StdEng{} + b.flag = 0 + + if a.Dtype() != b.Dtype() { + return true + } + if !a.Shape().Eq(b.Shape()) { + return true + } + + correct, err := a.Add(b) + 
we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := a.Engine().(Adder) + we = we || !ok + + var ret, reuse *Dense + if reuseA { + ret, err = a.Add(b, WithReuse(a)) + reuse = a + } else { + ret, err = a.Add(b, WithReuse(b)) + reuse = b + } + + if err, retEarly := qcErrCheck(t, "Add", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + + if reuse != ret { + t.Errorf("Expected reuse to be the same as retVal") + return false + } + + return true + } + if err := quick.Check(mut, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Reuse Mutation test for Add failed: %v", err) + } + } func TestDense_SubScalar_reuse(t *testing.T) { inv1 := func(q *Dense) bool { @@ -1162,7 +1541,7 @@ func TestDense_SubScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -1194,7 +1573,7 @@ func TestDense_SubScalar_reuse(t *testing.T) { b := identityVal(0, q.t) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -1220,6 +1599,58 @@ func TestDense_SubScalar_reuse(t *testing.T) { if err := quick.Check(inv2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { t.Errorf("Inv test for Sub (scalar as left, tensor as right) failed: %v", err) } + mut := func(a, b *Dense, reuseA bool) bool { + // req because we're only testing on one kind of tensor/engine combo + a.e = StdEng{} + a.oe = StdEng{} + a.flag = 0 + b.e = StdEng{} + b.oe = StdEng{} + b.flag = 0 + + if a.Dtype() != b.Dtype() { + return true + } + if !a.Shape().Eq(b.Shape()) { + return true + } + + correct, err := a.Sub(b) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) + _, ok := a.Engine().(Suber) + we = we || !ok + + var ret, reuse *Dense + if reuseA { + ret, err = a.Sub(b, WithReuse(a)) + reuse = a + } else { + ret, err = a.Sub(b, WithReuse(b)) + reuse = b + } + + if err, retEarly := qcErrCheck(t, "Sub", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + + if reuse != ret { + t.Errorf("Expected reuse to be the same as retVal") + return false + } + + return true + } + if err := quick.Check(mut, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Reuse Mutation test for Sub failed: %v", err) + } + } func TestDense_MulScalar_reuse(t *testing.T) { iden1 := func(q *Dense) bool { @@ -1228,7 +1659,7 @@ func TestDense_MulScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1260,7 +1691,7 @@ func TestDense_MulScalar_reuse(t *testing.T) { b := identityVal(1, q.t) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1286,6 +1717,58 @@ 
func TestDense_MulScalar_reuse(t *testing.T) { t.Errorf("Identity test for Mul (scalar as left, tensor as right) failed: %v", err) } + mut := func(a, b *Dense, reuseA bool) bool { + // req because we're only testing on one kind of tensor/engine combo + a.e = StdEng{} + a.oe = StdEng{} + a.flag = 0 + b.e = StdEng{} + b.oe = StdEng{} + b.flag = 0 + + if a.Dtype() != b.Dtype() { + return true + } + if !a.Shape().Eq(b.Shape()) { + return true + } + + correct, err := a.Mul(b) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := a.Engine().(Muler) + we = we || !ok + + var ret, reuse *Dense + if reuseA { + ret, err = a.Mul(b, WithReuse(a)) + reuse = a + } else { + ret, err = a.Mul(b, WithReuse(b)) + reuse = b + } + + if err, retEarly := qcErrCheck(t, "Mul", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + + if reuse != ret { + t.Errorf("Expected reuse to be the same as retVal") + return false + } + + return true + } + if err := quick.Check(mut, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Reuse Mutation test for Mul failed: %v", err) + } + } func TestDense_DivScalar_reuse(t *testing.T) { inv1 := func(q *Dense) bool { @@ -1294,7 +1777,7 @@ func TestDense_DivScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Diver) we = we || !ok @@ -1321,6 +1804,58 @@ func TestDense_DivScalar_reuse(t *testing.T) { t.Errorf("Inv test for Div (tensor as left, scalar as right) failed: %v", err) } + mut := func(a, b *Dense, reuseA bool) bool { + // req because we're only testing on one kind of tensor/engine combo + a.e = StdEng{} + a.oe = StdEng{} + a.flag = 0 + b.e = StdEng{} + b.oe = StdEng{} + b.flag = 0 + + if a.Dtype() != b.Dtype() { + return true + } + if !a.Shape().Eq(b.Shape()) { + return true + } + + correct, err := a.Div(b) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := a.Engine().(Diver) + we = we || !ok + + var ret, reuse *Dense + if reuseA { + ret, err = a.Div(b, WithReuse(a)) + reuse = a + } else { + ret, err = a.Div(b, WithReuse(b)) + reuse = b + } + + if err, retEarly := qcErrCheck(t, "Div", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + + if reuse != ret { + t.Errorf("Expected reuse to be the same as retVal") + return false + } + + return true + } + if err := quick.Check(mut, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Reuse Mutation test for Div failed: %v", err) + } + } func TestDense_PowScalar_reuse(t *testing.T) { iden1 := func(q *Dense) bool { @@ -1329,7 +1864,7 @@ func TestDense_PowScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := q.Engine().(Power) we = we || !ok @@ -1366,7 +1901,7 @@ func TestDense_AddScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || 
!ok @@ -1395,7 +1930,7 @@ func TestDense_AddScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1426,7 +1961,7 @@ func TestDense_SubScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -1458,7 +1993,7 @@ func TestDense_MulScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1487,7 +2022,7 @@ func TestDense_MulScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1518,7 +2053,7 @@ func TestDense_DivScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Diver) we = we || !ok @@ -1550,7 +2085,7 @@ func TestDense_PowScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := q.Engine().(Power) we = we || !ok diff --git a/dense_assign.go b/dense_assign.go index 0fdc1d4..5f44897 100644 --- a/dense_assign.go +++ b/dense_assign.go @@ -1,6 +1,8 @@ package tensor -import "github.com/pkg/errors" +import ( + "github.com/pkg/errors" +) func overlaps(a, b DenseTensor) bool { if a.cap() == 0 || b.cap() == 0 { @@ -8,14 +10,14 @@ func overlaps(a, b DenseTensor) bool { } aarr := a.arr() barr := b.arr() - if aarr.Ptr == barr.Ptr { + if aarr.Uintptr() == barr.Uintptr() { return true } - aptr := uintptr(aarr.Ptr) - bptr := uintptr(barr.Ptr) + aptr := aarr.Uintptr() + bptr := barr.Uintptr() - capA := aptr + uintptr(aarr.C)*a.Dtype().Size() - capB := bptr + uintptr(barr.C)*b.Dtype().Size() + capA := aptr + uintptr(cap(aarr.Header.Raw)) + capB := bptr + uintptr(cap(barr.Header.Raw)) switch { case aptr < bptr: @@ -84,11 +86,11 @@ func assignArray(dest, src DenseTensor) (err error) { return } dap := dest.Info() - sap := NewAP(tmpShape, newStrides) - sap.o = src.Info().o + sap := MakeAP(tmpShape, newStrides, src.Info().o, src.Info().Δ) - diter := NewFlatIterator(dap) - siter := NewFlatIterator(sap) + diter := newFlatIterator(dap) + siter := newFlatIterator(&sap) _, err = copyDenseIter(dest, src, diter, siter) + sap.zeroOnly() // cleanup, but not entirely because tmpShape and tmpStrides are separately cleaned up. Don't double free return } diff --git a/dense_cmp.go b/dense_cmp.go index 4ffaadf..d7770ac 100644 --- a/dense_cmp.go +++ b/dense_cmp.go @@ -1,9 +1,9 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import "github.com/pkg/errors" +// Code generated by genlib2. DO NOT EDIT. + // Gt performs t > other elementwise. 
Both t and other must have the same shape. // Acceptable FuncOpts are: UseUnsafe(), AsSameType(), WithReuse(). //UseUnsafe() will ensure that the same type is returned. diff --git a/dense_cmp_test.go b/dense_cmp_test.go index 4c1db8e..82e8518 100644 --- a/dense_cmp_test.go +++ b/dense_cmp_test.go @@ -1,16 +1,18 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( "reflect" "testing" "testing/quick" + + "gorgonia.org/dtype" ) +// Code generated by genlib2. DO NOT EDIT. + func TestDense_Gt(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Gter) we = we || !ok @@ -62,13 +64,13 @@ func TestDense_Gt(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gt failed: %v", err) + t.Errorf("Transitivity test for Gt failed: %v", err) } } func TestDense_Gte(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Gteer) we = we || !ok @@ -120,13 +122,13 @@ func TestDense_Gte(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gte failed: %v", err) + t.Errorf("Transitivity test for Gte failed: %v", err) } } func TestDense_Lt(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Lter) we = we || !ok @@ -178,13 +180,13 @@ func TestDense_Lt(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lt failed: %v", err) + t.Errorf("Transitivity test for Lt failed: %v", err) } } func TestDense_Lte(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Lteer) we = we || !ok @@ -236,13 +238,13 @@ func TestDense_Lte(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lte failed: %v", err) + t.Errorf("Transitivity test for Lte failed: %v", err) } } func TestDense_ElEq(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -294,11 +296,11 @@ func TestDense_ElEq(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } symFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -328,12 +330,12 @@ func TestDense_ElEq(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } } func TestDense_ElNe(t *testing.T) { symFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -363,16 +365,16 @@ func TestDense_ElNe(t *testing.T) { } if err := quick.Check(symFn, 
&quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElNe failed: %v", err) + t.Errorf("Transitivity test for ElNe failed: %v", err) } } func TestDense_Gt_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Gter) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -422,17 +424,17 @@ func TestDense_Gt_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gt failed: %v", err) + t.Errorf("Transitivity test for Gt failed: %v", err) } } func TestDense_Gte_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Gteer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -482,17 +484,17 @@ func TestDense_Gte_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gte failed: %v", err) + t.Errorf("Transitivity test for Gte failed: %v", err) } } func TestDense_Lt_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Lter) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -542,17 +544,17 @@ func TestDense_Lt_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lt failed: %v", err) + t.Errorf("Transitivity test for Lt failed: %v", err) } } func TestDense_Lte_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Lteer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -602,17 +604,17 @@ func TestDense_Lte_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lte failed: %v", err) + t.Errorf("Transitivity test for Lte failed: %v", err) } } func TestDense_ElEq_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { 
+ if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -662,15 +664,15 @@ func TestDense_ElEq_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } symFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -699,16 +701,16 @@ func TestDense_ElEq_assame(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } } func TestDense_ElNe_assame(t *testing.T) { symFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -737,12 +739,12 @@ func TestDense_ElNe_assame(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElNe failed: %v", err) + t.Errorf("Transitivity test for ElNe failed: %v", err) } } func TestDense_GtScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Gter) we = we || !ok @@ -792,13 +794,13 @@ func TestDense_GtScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gt failed: %v", err) + t.Errorf("Transitivity test for Gt failed: %v", err) } } func TestDense_GteScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Gteer) we = we || !ok @@ -848,13 +850,13 @@ func TestDense_GteScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gte failed: %v", err) + t.Errorf("Transitivity test for Gte failed: %v", err) } } func TestDense_LtScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Lter) we = we || !ok @@ -904,13 +906,13 @@ func TestDense_LtScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lt failed: %v", err) + t.Errorf("Transitivity test for Lt failed: %v", err) } } func TestDense_LteScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok 
:= q.Engine().(Lteer) we = we || !ok @@ -960,13 +962,13 @@ func TestDense_LteScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lte failed: %v", err) + t.Errorf("Transitivity test for Lte failed: %v", err) } } func TestDense_ElEqScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -1016,11 +1018,11 @@ func TestDense_ElEqScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } symFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -1048,12 +1050,12 @@ func TestDense_ElEqScalar(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Symmetry test for ElEq failed: %v", err) + t.Errorf("Symmetry test for ElEq failed: %v", err) } } func TestDense_ElNeScalar(t *testing.T) { symFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -1081,16 +1083,16 @@ func TestDense_ElNeScalar(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Symmetry test for ElNe failed: %v", err) + t.Errorf("Symmetry test for ElNe failed: %v", err) } } func TestDense_GtScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Gter) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1138,17 +1140,17 @@ func TestDense_GtScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gt failed: %v", err) + t.Errorf("Transitivity test for Gt failed: %v", err) } } func TestDense_GteScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Gteer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1196,17 +1198,17 @@ func TestDense_GteScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gte failed: %v", err) + t.Errorf("Transitivity test for Gte failed: %v", err) } } func TestDense_LtScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Lter) we = we || !ok - if err := typeclassCheck(q.Dtype(), 
nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1254,17 +1256,17 @@ func TestDense_LtScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lt failed: %v", err) + t.Errorf("Transitivity test for Lt failed: %v", err) } } func TestDense_LteScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Lteer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1312,17 +1314,17 @@ func TestDense_LteScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lte failed: %v", err) + t.Errorf("Transitivity test for Lte failed: %v", err) } } func TestDense_ElEqScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1370,15 +1372,15 @@ func TestDense_ElEqScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } symFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1405,16 +1407,16 @@ func TestDense_ElEqScalar_assame(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Symmetry test for ElEq failed: %v", err) + t.Errorf("Symmetry test for ElEq failed: %v", err) } } func TestDense_ElNeScalar_assame(t *testing.T) { symFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1441,6 +1443,6 @@ func TestDense_ElNeScalar_assame(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Symmetry test for ElNe failed: %v", err) + t.Errorf("Symmetry test for ElNe failed: 
%v", err) } } diff --git a/dense_colmajor_linalg_test.go b/dense_colmajor_linalg_test.go new file mode 100644 index 0000000..feccfc5 --- /dev/null +++ b/dense_colmajor_linalg_test.go @@ -0,0 +1,483 @@ +package tensor + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +var colMajorTraceTests = []struct { + data interface{} + + correct interface{} + err bool +}{ + {[]int{0, 1, 2, 3, 4, 5}, int(4), false}, + {[]int8{0, 1, 2, 3, 4, 5}, int8(4), false}, + {[]int16{0, 1, 2, 3, 4, 5}, int16(4), false}, + {[]int32{0, 1, 2, 3, 4, 5}, int32(4), false}, + {[]int64{0, 1, 2, 3, 4, 5}, int64(4), false}, + {[]uint{0, 1, 2, 3, 4, 5}, uint(4), false}, + {[]uint8{0, 1, 2, 3, 4, 5}, uint8(4), false}, + {[]uint16{0, 1, 2, 3, 4, 5}, uint16(4), false}, + {[]uint32{0, 1, 2, 3, 4, 5}, uint32(4), false}, + {[]uint64{0, 1, 2, 3, 4, 5}, uint64(4), false}, + {[]float32{0, 1, 2, 3, 4, 5}, float32(4), false}, + {[]float64{0, 1, 2, 3, 4, 5}, float64(4), false}, + {[]complex64{0, 1, 2, 3, 4, 5}, complex64(4), false}, + {[]complex128{0, 1, 2, 3, 4, 5}, complex128(4), false}, + {[]bool{true, false, true, false, true, false}, nil, true}, +} + +func TestColMajor_Dense_Trace(t *testing.T) { + assert := assert.New(t) + for i, tts := range colMajorTraceTests { + T := New(WithShape(2, 3), AsFortran(tts.data)) + trace, err := T.Trace() + + if checkErr(t, tts.err, err, "Trace", i) { + continue + } + assert.Equal(tts.correct, trace) + + // + T = New(WithBacking(tts.data)) + _, err = T.Trace() + if err == nil { + t.Error("Expected an error when Trace() on non-matrices") + } + } +} + +var colMajorInnerTests = []struct { + a, b interface{} + shapeA, shapeB Shape + + correct interface{} + err bool +}{ + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, float64(5), false}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3, 1}, Shape{3}, float64(5), false}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{1, 3}, Shape{3}, float64(5), false}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3, 1}, Shape{3, 1}, float64(5), false}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{1, 3}, Shape{3, 1}, float64(5), false}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3, 1}, Shape{1, 3}, float64(5), false}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{1, 3}, Shape{1, 3}, float64(5), false}, + + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, float32(5), false}, + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3, 1}, Shape{3}, float32(5), false}, + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{1, 3}, Shape{3}, float32(5), false}, + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3, 1}, Shape{3, 1}, float32(5), false}, + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{1, 3}, Shape{3, 1}, float32(5), false}, + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3, 1}, Shape{1, 3}, float32(5), false}, + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{1, 3}, Shape{1, 3}, float32(5), false}, + + // stupids: type differences + {Range(Int, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, nil, true}, + {Range(Float32, 0, 3), Range(Byte, 0, 3), Shape{3}, Shape{3}, nil, true}, + {Range(Float64, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, nil, true}, + {Range(Float32, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, nil, true}, + + // differing size + {Range(Float64, 0, 4), Range(Float64, 0, 3), Shape{4}, Shape{3}, nil, true}, + + // A is not a matrix + {Range(Float64, 0, 4), Range(Float64, 0, 3), Shape{2, 2}, Shape{3}, nil, true}, +} + +func 
TestColMajor_Dense_Inner(t *testing.T) { + for i, its := range colMajorInnerTests { + a := New(WithShape(its.shapeA...), AsFortran(its.a)) + b := New(WithShape(its.shapeB...), AsFortran(its.b)) + + T, err := a.Inner(b) + if checkErr(t, its.err, err, "Inner", i) { + continue + } + + assert.Equal(t, its.correct, T) + } +} + +var colMajorMatVecMulTests = []linalgTest{ + // Float64s + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, false}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3, 1}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, false}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{1, 3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, false}, + + // float64s with transposed matrix + {Range(Float64, 0, 6), Range(Float64, 0, 2), Shape{2, 3}, Shape{2}, true, false, + Range(Float64, 52, 55), Range(Float64, 100, 103), Shape{3}, Shape{3}, + []float64{3, 4, 5}, []float64{103, 105, 107}, []float64{106, 109, 112}, Shape{3}, false, false, false}, + + // Float32s + {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2}, + []float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false}, + {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3, 1}, false, false, + Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2}, + []float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false}, + {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{1, 3}, false, false, + Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2}, + []float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false}, + + // stupids : unpossible shapes (wrong A) + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{6}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, + + //stupids: bad A shape + {Range(Float64, 0, 8), Range(Float64, 0, 3), Shape{4, 2}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, + + //stupids: bad B shape + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, + + //stupids: bad reuse + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 55), Range(Float64, 100, 102), Shape{3}, Shape{2}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, true}, + + //stupids: bad incr shape + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 105), Shape{2}, Shape{5}, + []float64{5, 14}, []float64{105, 
115}, []float64{110, 129}, Shape{2}, false, true, false}, + + // stupids: type mismatch A and B + {Range(Float64, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, + + // stupids: type mismatch A and B + {Range(Float32, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, + + // stupids: type mismatch A and B + {Range(Float64, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, + + // stupids: type mismatch A and B + {Range(Float32, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, + + // stupids: type mismatch A and B (non-Float) + {Range(Float64, 0, 6), Range(Int, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, + + // stupids: type mismatch, reuse + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float32, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, true}, + + // stupids: type mismatch, incr + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float32, 100, 103), Shape{2}, Shape{3}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, true, false}, + + // stupids: type mismatch, incr not a Number + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 54), []bool{true, true, true}, Shape{2}, Shape{3}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, true, false}, +} + +func TestColMajor_Dense_MatVecMul(t *testing.T) { + assert := assert.New(t) + for i, mvmt := range colMajorMatVecMulTests { + a := New(WithShape(mvmt.shapeA...), AsFortran(mvmt.a)) + b := New(WithShape(mvmt.shapeB...), AsFortran(mvmt.b)) + + if mvmt.transA { + if err := a.T(); err != nil { + t.Error(err) + continue + } + } + + T, err := a.MatVecMul(b) + if checkErr(t, mvmt.err, err, "Safe", i) { + continue + } + + assert.True(mvmt.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + assert.Equal(mvmt.correct, T.Data()) + + // incr + incr := New(WithShape(mvmt.shapeI...), AsFortran(mvmt.incr)) + T, err = a.MatVecMul(b, WithIncr(incr)) + if checkErr(t, mvmt.errIncr, err, "WithIncr", i) { + continue + } + + assert.True(mvmt.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + assert.Equal(mvmt.correctIncr, T.Data()) + + // reuse + reuse := New(WithShape(mvmt.shapeR...), AsFortran(mvmt.reuse)) + T, err = a.MatVecMul(b, WithReuse(reuse)) + if checkErr(t, mvmt.errReuse, err, "WithReuse", i) { + continue + } + + assert.True(mvmt.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + assert.Equal(mvmt.correct, 
T.Data()) + + // reuse AND incr + T, err = a.MatVecMul(b, WithIncr(incr), WithReuse(reuse)) + if checkErr(t, mvmt.err, err, "WithReuse and WithIncr", i) { + continue + } + assert.True(mvmt.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + assert.Equal(mvmt.correctIncrReuse, T.Data()) + } +} + +var colMajorMatMulTests = []linalgTest{ + // Float64s + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, false, false, false}, + + // Float32s + {Range(Float32, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float32, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float32{10, 28, 13, 40}, []float32{110, 130, 114, 143}, []float32{120, 158, 127, 183}, Shape{2, 2}, false, false, false}, + + // Edge cases - Row Vecs (Float64) + {Range(Float64, 0, 2), Range(Float64, 0, 3), Shape{2, 1}, Shape{1, 3}, false, false, + Range(Float64, 10, 16), Range(Float64, 100, 106), Shape{2, 3}, Shape{2, 3}, + []float64{0, 0, 0, 1, 0, 2}, []float64{100, 103, 101, 105, 102, 107}, []float64{100, 103, 101, 106, 102, 109}, Shape{2, 3}, false, false, false}, + {Range(Float64, 0, 2), Range(Float64, 0, 6), Shape{1, 2}, Shape{2, 3}, false, false, + Range(Float64, 10, 13), Range(Float64, 100, 103), Shape{1, 3}, Shape{1, 3}, + []float64{3, 4, 5}, []float64{103, 105, 107}, []float64{106, 109, 112}, Shape{1, 3}, false, false, false}, + {Range(Float64, 0, 2), Range(Float64, 0, 2), Shape{1, 2}, Shape{2, 1}, false, false, + Range(Float64, 0, 1), Range(Float64, 100, 101), Shape{1, 1}, Shape{1, 1}, + []float64{1}, []float64{101}, []float64{102}, Shape{1, 1}, false, false, false}, + + // Edge cases - Row Vecs (Float32) + {Range(Float32, 0, 2), Range(Float32, 0, 3), Shape{2, 1}, Shape{1, 3}, false, false, + Range(Float32, 10, 16), Range(Float32, 100, 106), Shape{2, 3}, Shape{2, 3}, + []float32{0, 0, 0, 1, 0, 2}, []float32{100, 103, 101, 105, 102, 107}, []float32{100, 103, 101, 106, 102, 109}, Shape{2, 3}, false, false, false}, + {Range(Float32, 0, 2), Range(Float32, 0, 6), Shape{1, 2}, Shape{2, 3}, false, false, + Range(Float32, 10, 13), Range(Float32, 100, 103), Shape{1, 3}, Shape{1, 3}, + []float32{3, 4, 5}, []float32{103, 105, 107}, []float32{106, 109, 112}, Shape{1, 3}, false, false, false}, + {Range(Float32, 0, 2), Range(Float32, 0, 2), Shape{1, 2}, Shape{2, 1}, false, false, + Range(Float32, 0, 1), Range(Float32, 100, 101), Shape{1, 1}, Shape{1, 1}, + []float32{1}, []float32{101}, []float32{102}, Shape{1, 1}, false, false, false}, + + // stupids - bad shape (not matrices): + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{6}, false, false, + Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false}, + + // stupids - bad shape (incompatible shapes): + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{6, 1}, false, false, + Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false}, + + // stupids - bad shape (bad reuse shape): + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 57), 
Range(Float64, 100, 104), Shape{5}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, false, false, true}, + + // stupids - bad shape (bad incr shape): + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{4}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, false, true, false}, + + // stupids - type mismatch (a,b) + {Range(Float64, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false}, + + // stupids - type mismatch (a,b) + {Range(Float32, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false}, + + // stupids type mismatch (b not float) + {Range(Float64, 0, 6), Range(Int, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false}, + + // stupids type mismatch (a not float) + {Range(Int, 0, 6), Range(Int, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false}, + + // stupids: type mismatch (incr) + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, false, true, false}, + + // stupids: type mismatch (reuse) + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float32, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, false, false, true}, + + // stupids: type mismatch (reuse) + {Range(Float32, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float32{10, 28, 13, 40}, []float32{110, 130, 114, 143}, []float32{120, 158, 127, 183}, Shape{2, 2}, false, false, true}, +} + +func TestColMajorDense_MatMul(t *testing.T) { + assert := assert.New(t) + for i, mmt := range colMajorMatMulTests { + a := New(WithShape(mmt.shapeA...), AsFortran(mmt.a)) + b := New(WithShape(mmt.shapeB...), AsFortran(mmt.b)) + + T, err := a.MatMul(b) + if checkErr(t, mmt.err, err, "Safe", i) { + continue + } + assert.True(mmt.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + assert.Equal(mmt.correct, T.Data(), "Test %d", i) + + // incr + incr := New(WithShape(mmt.shapeI...), AsFortran(mmt.incr)) + T, err = a.MatMul(b, WithIncr(incr)) + if checkErr(t, mmt.errIncr, err, "WithIncr", i) { + continue + } + assert.True(mmt.correctShape.Eq(T.Shape())) + assert.Equal(mmt.correctIncr, T.Data()) + + // reuse + reuse := 
New(WithShape(mmt.shapeR...), AsFortran(mmt.reuse)) + T, err = a.MatMul(b, WithReuse(reuse)) + + if checkErr(t, mmt.errReuse, err, "WithReuse", i) { + continue + } + assert.True(mmt.correctShape.Eq(T.Shape())) + assert.Equal(mmt.correct, T.Data()) + + // reuse AND incr + T, err = a.MatMul(b, WithIncr(incr), WithReuse(reuse)) + if checkErr(t, mmt.err, err, "WithIncr and WithReuse", i) { + continue + } + assert.True(mmt.correctShape.Eq(T.Shape())) + assert.Equal(mmt.correctIncrReuse, T.Data()) + } +} + +var colMajorOuterTests = []linalgTest{ + // Float64s + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, + Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, + []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + false, false, false}, + + // Float32s + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, false, false, + Range(Float32, 52, 61), Range(Float32, 100, 109), Shape{3, 3}, Shape{3, 3}, + []float32{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float32{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float32{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + false, false, false}, + + // stupids - a or b not vector + {Range(Float64, 0, 3), Range(Float64, 0, 6), Shape{3}, Shape{3, 2}, false, false, + Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, + []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + true, false, false}, + + // stupids - bad incr shape + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, + Range(Float64, 52, 61), Range(Float64, 100, 106), Shape{3, 3}, Shape{3, 2}, + []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + false, true, false}, + + // stupids - bad reuse shape + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, + Range(Float64, 52, 58), Range(Float64, 100, 109), Shape{3, 2}, Shape{3, 3}, + []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + false, false, true}, + + // stupids - b not Float + {Range(Float64, 0, 3), Range(Int, 0, 3), Shape{3}, Shape{3}, false, false, + Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, + []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + true, false, false}, + + // stupids - a not Float + {Range(Int, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, + Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, + []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + true, false, false}, + + // stupids - a-b type mismatch + {Range(Float64, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, false, false, + Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, + []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + true, false, false}, + + // 
stupids a-b type mismatch + {Range(Float32, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, + Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, + []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + true, false, false}, +} + +func TestColMajor_Dense_Outer(t *testing.T) { + assert := assert.New(t) + for i, ot := range colMajorOuterTests { + a := New(WithShape(ot.shapeA...), AsFortran(ot.a)) + b := New(WithShape(ot.shapeB...), AsFortran(ot.b)) + + T, err := a.Outer(b) + if checkErr(t, ot.err, err, "Safe", i) { + continue + } + assert.True(ot.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + assert.Equal(ot.correct, T.Data()) + + // incr + incr := New(WithShape(ot.shapeI...), AsFortran(ot.incr)) + T, err = a.Outer(b, WithIncr(incr)) + if checkErr(t, ot.errIncr, err, "WithIncr", i) { + continue + } + assert.True(ot.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + assert.Equal(ot.correctIncr, T.Data()) + + // reuse + reuse := New(WithShape(ot.shapeR...), AsFortran(ot.reuse)) + T, err = a.Outer(b, WithReuse(reuse)) + if checkErr(t, ot.errReuse, err, "WithReuse", i) { + continue + } + assert.True(ot.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + assert.Equal(ot.correct, T.Data()) + + // reuse AND incr + T, err = a.Outer(b, WithIncr(incr), WithReuse(reuse)) + if err != nil { + t.Errorf("Reuse and Incr error'd %+v", err) + continue + } + assert.True(ot.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + assert.Equal(ot.correctIncrReuse, T.Data()) + } +} diff --git a/dense_compat.go b/dense_compat.go index 151ae0a..cf6764d 100644 --- a/dense_compat.go +++ b/dense_compat.go @@ -1,19 +1,24 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor +// Code generated by genlib2. DO NOT EDIT. + import ( "fmt" "math" "math/cmplx" "reflect" + arrow "github.com/apache/arrow/go/arrow" + arrowArray "github.com/apache/arrow/go/arrow/array" + "github.com/apache/arrow/go/arrow/bitutil" + arrowTensor "github.com/apache/arrow/go/arrow/tensor" "github.com/chewxy/math32" "github.com/pkg/errors" "gonum.org/v1/gonum/mat" + "gorgonia.org/dtype" ) -func convFromFloat64s(to Dtype, data []float64) interface{} { +func convFromFloat64s(to dtype.Dtype, data []float64) interface{} { switch to { case Int: retVal := make([]int, len(data)) @@ -394,14 +399,12 @@ func FromMat64(m *mat.Dense, opts ...FuncOpt) *Dense { func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { // checks: if !t.IsNativelyAccessible() { - err = errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible") - return + return nil, errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible") } if !t.IsMatrix() { // error - err = errors.Errorf("Cannot convert *Dense to *mat.Dense. Expected number of dimensions: <=2, T has got %d dimensions (Shape: %v)", t.Dims(), t.Shape()) - return + return nil, errors.Errorf("Cannot convert *Dense to *mat.Dense. Expected number of dimensions: <=2, T has got %d dimensions (Shape: %v)", t.Dims(), t.Shape()) } fo := ParseFuncOpts(opts...) 
@@ -414,13 +417,13 @@ func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { var data []float64 switch { - case t.t == Float64 && toCopy && !t.IsMaterializable(): + case t.t == Float64 && toCopy && !t.RequiresIterator() && t.viewOf == 0: data = make([]float64, t.len()) copy(data, t.Float64s()) - case !t.IsMaterializable(): + case !t.RequiresIterator() && t.viewOf == 0: data = convToFloat64s(t) default: - it := NewFlatIterator(t.AP) + it := newFlatIterator(&t.AP) var next int for next, err = it.Next(); err == nil; next, err = it.Next() { if err = handleNoOp(err); err != nil { @@ -435,3 +438,179 @@ func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { retVal = mat.NewDense(r, c, data) return } + +// FromArrowArray converts an "arrow/array".Interface into a Tensor of matching DataType. +func FromArrowArray(a arrowArray.Interface) *Dense { + a.Retain() + defer a.Release() + + r := a.Len() + + // TODO(poopoothegorilla): instead of creating bool ValidMask maybe + // bitmapBytes can be used from arrow API + mask := make([]bool, r) + for i := 0; i < r; i++ { + mask[i] = a.IsNull(i) + } + + switch a.DataType() { + case arrow.BinaryTypes.String: + backing := make([]string, r) + for i := 0; i < r; i++ { + backing[i] = a.(*arrowArray.String).Value(i) + } + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + case arrow.FixedWidthTypes.Boolean: + backing := make([]bool, r) + for i := 0; i < r; i++ { + backing[i] = a.(*arrowArray.Boolean).Value(i) + } + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + case arrow.PrimitiveTypes.Int8: + backing := a.(*arrowArray.Int8).Int8Values() + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + case arrow.PrimitiveTypes.Int16: + backing := a.(*arrowArray.Int16).Int16Values() + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + case arrow.PrimitiveTypes.Int32: + backing := a.(*arrowArray.Int32).Int32Values() + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + case arrow.PrimitiveTypes.Int64: + backing := a.(*arrowArray.Int64).Int64Values() + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + case arrow.PrimitiveTypes.Uint8: + backing := a.(*arrowArray.Uint8).Uint8Values() + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + case arrow.PrimitiveTypes.Uint16: + backing := a.(*arrowArray.Uint16).Uint16Values() + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + case arrow.PrimitiveTypes.Uint32: + backing := a.(*arrowArray.Uint32).Uint32Values() + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + case arrow.PrimitiveTypes.Uint64: + backing := a.(*arrowArray.Uint64).Uint64Values() + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + case arrow.PrimitiveTypes.Float32: + backing := a.(*arrowArray.Float32).Float32Values() + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + case arrow.PrimitiveTypes.Float64: + backing := a.(*arrowArray.Float64).Float64Values() + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + default: + panic(fmt.Sprintf("Unsupported Arrow DataType - %v", a.DataType())) + } + + panic("Unreachable") +} + +// FromArrowTensor converts an "arrow/tensor".Interface into a Tensor of matching DataType. 
+func FromArrowTensor(a arrowTensor.Interface) *Dense { + a.Retain() + defer a.Release() + + if !a.IsContiguous() { + panic("Non-contiguous data is Unsupported") + } + + var shape []int + for _, val := range a.Shape() { + shape = append(shape, int(val)) + } + + l := a.Len() + validMask := a.Data().Buffers()[0].Bytes() + dataOffset := a.Data().Offset() + mask := make([]bool, l) + for i := 0; i < l; i++ { + mask[i] = len(validMask) != 0 && bitutil.BitIsNotSet(validMask, dataOffset+i) + } + + switch a.DataType() { + case arrow.PrimitiveTypes.Int8: + backing := a.(*arrowTensor.Int8).Int8Values() + if a.IsColMajor() { + return New(WithShape(shape...), AsFortran(backing, mask)) + } + + return New(WithShape(shape...), WithBacking(backing, mask)) + case arrow.PrimitiveTypes.Int16: + backing := a.(*arrowTensor.Int16).Int16Values() + if a.IsColMajor() { + return New(WithShape(shape...), AsFortran(backing, mask)) + } + + return New(WithShape(shape...), WithBacking(backing, mask)) + case arrow.PrimitiveTypes.Int32: + backing := a.(*arrowTensor.Int32).Int32Values() + if a.IsColMajor() { + return New(WithShape(shape...), AsFortran(backing, mask)) + } + + return New(WithShape(shape...), WithBacking(backing, mask)) + case arrow.PrimitiveTypes.Int64: + backing := a.(*arrowTensor.Int64).Int64Values() + if a.IsColMajor() { + return New(WithShape(shape...), AsFortran(backing, mask)) + } + + return New(WithShape(shape...), WithBacking(backing, mask)) + case arrow.PrimitiveTypes.Uint8: + backing := a.(*arrowTensor.Uint8).Uint8Values() + if a.IsColMajor() { + return New(WithShape(shape...), AsFortran(backing, mask)) + } + + return New(WithShape(shape...), WithBacking(backing, mask)) + case arrow.PrimitiveTypes.Uint16: + backing := a.(*arrowTensor.Uint16).Uint16Values() + if a.IsColMajor() { + return New(WithShape(shape...), AsFortran(backing, mask)) + } + + return New(WithShape(shape...), WithBacking(backing, mask)) + case arrow.PrimitiveTypes.Uint32: + backing := a.(*arrowTensor.Uint32).Uint32Values() + if a.IsColMajor() { + return New(WithShape(shape...), AsFortran(backing, mask)) + } + + return New(WithShape(shape...), WithBacking(backing, mask)) + case arrow.PrimitiveTypes.Uint64: + backing := a.(*arrowTensor.Uint64).Uint64Values() + if a.IsColMajor() { + return New(WithShape(shape...), AsFortran(backing, mask)) + } + + return New(WithShape(shape...), WithBacking(backing, mask)) + case arrow.PrimitiveTypes.Float32: + backing := a.(*arrowTensor.Float32).Float32Values() + if a.IsColMajor() { + return New(WithShape(shape...), AsFortran(backing, mask)) + } + + return New(WithShape(shape...), WithBacking(backing, mask)) + case arrow.PrimitiveTypes.Float64: + backing := a.(*arrowTensor.Float64).Float64Values() + if a.IsColMajor() { + return New(WithShape(shape...), AsFortran(backing, mask)) + } + + return New(WithShape(shape...), WithBacking(backing, mask)) + default: + panic(fmt.Sprintf("Unsupported Arrow DataType - %v", a.DataType())) + } + + panic("Unreachable") +} diff --git a/dense_compat_test.go b/dense_compat_test.go index 494fc25..442b7e7 100644 --- a/dense_compat_test.go +++ b/dense_compat_test.go @@ -1,19 +1,24 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor +// Code generated by genlib2. DO NOT EDIT. 
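// Illustrative sketch (not part of the patch): converting an Arrow array into a
// *Dense with the FromArrowArray entry point introduced above. Null entries in
// the Arrow validity bitmap become masked elements in the resulting tensor.
// Assumes the Arrow Go import paths used elsewhere in this change and that this
// package is imported as "gorgonia.org/tensor"; `pool`, `b` and `arr` are
// placeholder names.
package main

import (
	"fmt"

	arrowArray "github.com/apache/arrow/go/arrow/array"
	"github.com/apache/arrow/go/arrow/memory"
	"gorgonia.org/tensor"
)

func main() {
	pool := memory.NewGoAllocator()
	b := arrowArray.NewFloat64Builder(pool)
	defer b.Release()

	// The second argument marks which values are valid; false entries become
	// masked elements in the resulting *Dense.
	b.AppendValues([]float64{1, 2, 3, 4}, []bool{true, true, false, true})
	arr := b.NewArray()
	defer arr.Release()

	T := tensor.FromArrowArray(arr) // shape (4, 1); the third element is masked
	fmt.Println(T.Shape(), T.Mask())
}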
+ import ( "testing" + arrow "github.com/apache/arrow/go/arrow" + arrowArray "github.com/apache/arrow/go/arrow/array" + "github.com/apache/arrow/go/arrow/memory" + arrowTensor "github.com/apache/arrow/go/arrow/tensor" "github.com/stretchr/testify/assert" "gonum.org/v1/gonum/mat" + "gorgonia.org/dtype" ) var toMat64Tests = []struct { data interface{} sliced interface{} shape Shape - dt Dtype + dt dtype.Dtype }{ {Range(Int, 0, 6), []int{0, 1, 3, 4}, Shape{2, 3}, Int}, {Range(Int8, 0, 6), []int8{0, 1, 3, 4}, Shape{2, 3}, Int8}, @@ -105,3 +110,570 @@ func TestFromMat64(t *testing.T) { } } } + +var toArrowArrayTests = []struct { + data interface{} + valid []bool + dt arrow.DataType + shape Shape +}{ + { + data: Range(Int8, 0, 6), + valid: []bool{true, true, true, false, true, true}, + dt: arrow.PrimitiveTypes.Int8, + shape: Shape{6, 1}, + }, + { + data: Range(Int16, 0, 6), + valid: []bool{true, true, true, false, true, true}, + dt: arrow.PrimitiveTypes.Int16, + shape: Shape{6, 1}, + }, + { + data: Range(Int32, 0, 6), + valid: []bool{true, true, true, false, true, true}, + dt: arrow.PrimitiveTypes.Int32, + shape: Shape{6, 1}, + }, + { + data: Range(Int64, 0, 6), + valid: []bool{true, true, true, false, true, true}, + dt: arrow.PrimitiveTypes.Int64, + shape: Shape{6, 1}, + }, + { + data: Range(Uint8, 0, 6), + valid: []bool{true, true, true, false, true, true}, + dt: arrow.PrimitiveTypes.Uint8, + shape: Shape{6, 1}, + }, + { + data: Range(Uint16, 0, 6), + valid: []bool{true, true, true, false, true, true}, + dt: arrow.PrimitiveTypes.Uint16, + shape: Shape{6, 1}, + }, + { + data: Range(Uint32, 0, 6), + valid: []bool{true, true, true, false, true, true}, + dt: arrow.PrimitiveTypes.Uint32, + shape: Shape{6, 1}, + }, + { + data: Range(Uint64, 0, 6), + valid: []bool{true, true, true, false, true, true}, + dt: arrow.PrimitiveTypes.Uint64, + shape: Shape{6, 1}, + }, + { + data: Range(Float32, 0, 6), + valid: []bool{true, true, true, false, true, true}, + dt: arrow.PrimitiveTypes.Float32, + shape: Shape{6, 1}, + }, + { + data: Range(Float64, 0, 6), + valid: []bool{true, true, true, false, true, true}, + dt: arrow.PrimitiveTypes.Float64, + shape: Shape{6, 1}, + }, +} + +func TestFromArrowArray(t *testing.T) { + assert := assert.New(t) + var T *Dense + pool := memory.NewGoAllocator() + + for i, taat := range toArrowArrayTests { + var m arrowArray.Interface + + switch taat.dt { + case arrow.BinaryTypes.String: + b := arrowArray.NewStringBuilder(pool) + defer b.Release() + b.AppendValues( + []string{"0", "1", "2", "3", "4", "5"}, + taat.valid, + ) + m = b.NewArray() + defer m.Release() + case arrow.FixedWidthTypes.Boolean: + b := arrowArray.NewBooleanBuilder(pool) + defer b.Release() + b.AppendValues( + []bool{true, false, true, false, true, false}, + taat.valid, + ) + m = b.NewArray() + defer m.Release() + case arrow.PrimitiveTypes.Int8: + b := arrowArray.NewInt8Builder(pool) + defer b.Release() + b.AppendValues( + Range(Int8, 0, 6).([]int8), + taat.valid, + ) + m = b.NewArray() + defer m.Release() + case arrow.PrimitiveTypes.Int16: + b := arrowArray.NewInt16Builder(pool) + defer b.Release() + b.AppendValues( + Range(Int16, 0, 6).([]int16), + taat.valid, + ) + m = b.NewArray() + defer m.Release() + case arrow.PrimitiveTypes.Int32: + b := arrowArray.NewInt32Builder(pool) + defer b.Release() + b.AppendValues( + Range(Int32, 0, 6).([]int32), + taat.valid, + ) + m = b.NewArray() + defer m.Release() + case arrow.PrimitiveTypes.Int64: + b := arrowArray.NewInt64Builder(pool) + defer b.Release() + b.AppendValues( + 
Range(Int64, 0, 6).([]int64), + taat.valid, + ) + m = b.NewArray() + defer m.Release() + case arrow.PrimitiveTypes.Uint8: + b := arrowArray.NewUint8Builder(pool) + defer b.Release() + b.AppendValues( + Range(Uint8, 0, 6).([]uint8), + taat.valid, + ) + m = b.NewArray() + defer m.Release() + case arrow.PrimitiveTypes.Uint16: + b := arrowArray.NewUint16Builder(pool) + defer b.Release() + b.AppendValues( + Range(Uint16, 0, 6).([]uint16), + taat.valid, + ) + m = b.NewArray() + defer m.Release() + case arrow.PrimitiveTypes.Uint32: + b := arrowArray.NewUint32Builder(pool) + defer b.Release() + b.AppendValues( + Range(Uint32, 0, 6).([]uint32), + taat.valid, + ) + m = b.NewArray() + defer m.Release() + case arrow.PrimitiveTypes.Uint64: + b := arrowArray.NewUint64Builder(pool) + defer b.Release() + b.AppendValues( + Range(Uint64, 0, 6).([]uint64), + taat.valid, + ) + m = b.NewArray() + defer m.Release() + case arrow.PrimitiveTypes.Float32: + b := arrowArray.NewFloat32Builder(pool) + defer b.Release() + b.AppendValues( + Range(Float32, 0, 6).([]float32), + taat.valid, + ) + m = b.NewArray() + defer m.Release() + case arrow.PrimitiveTypes.Float64: + b := arrowArray.NewFloat64Builder(pool) + defer b.Release() + b.AppendValues( + Range(Float64, 0, 6).([]float64), + taat.valid, + ) + m = b.NewArray() + defer m.Release() + default: + t.Errorf("DataType not supported in tests: %v", taat.dt) + } + + T = FromArrowArray(m) + switch taat.dt { + case arrow.PrimitiveTypes.Int8: + conv := taat.data.([]int8) + assert.Equal(conv, T.Int8s(), "test %d: []int8 from %v", i, taat.dt) + case arrow.PrimitiveTypes.Int16: + conv := taat.data.([]int16) + assert.Equal(conv, T.Int16s(), "test %d: []int16 from %v", i, taat.dt) + case arrow.PrimitiveTypes.Int32: + conv := taat.data.([]int32) + assert.Equal(conv, T.Int32s(), "test %d: []int32 from %v", i, taat.dt) + case arrow.PrimitiveTypes.Int64: + conv := taat.data.([]int64) + assert.Equal(conv, T.Int64s(), "test %d: []int64 from %v", i, taat.dt) + case arrow.PrimitiveTypes.Uint8: + conv := taat.data.([]uint8) + assert.Equal(conv, T.Uint8s(), "test %d: []uint8 from %v", i, taat.dt) + case arrow.PrimitiveTypes.Uint16: + conv := taat.data.([]uint16) + assert.Equal(conv, T.Uint16s(), "test %d: []uint16 from %v", i, taat.dt) + case arrow.PrimitiveTypes.Uint32: + conv := taat.data.([]uint32) + assert.Equal(conv, T.Uint32s(), "test %d: []uint32 from %v", i, taat.dt) + case arrow.PrimitiveTypes.Uint64: + conv := taat.data.([]uint64) + assert.Equal(conv, T.Uint64s(), "test %d: []uint64 from %v", i, taat.dt) + case arrow.PrimitiveTypes.Float32: + conv := taat.data.([]float32) + assert.Equal(conv, T.Float32s(), "test %d: []float32 from %v", i, taat.dt) + case arrow.PrimitiveTypes.Float64: + conv := taat.data.([]float64) + assert.Equal(conv, T.Float64s(), "test %d: []float64 from %v", i, taat.dt) + default: + t.Errorf("DataType not supported in tests: %v", taat.dt) + } + for i, invalid := range T.Mask() { + assert.Equal(taat.valid[i], !invalid) + } + assert.True(T.Shape().Eq(taat.shape)) + } +} + +var toArrowTensorTests = []struct { + rowMajorData interface{} + colMajorData interface{} + rowMajorValid []bool + colMajorValid []bool + dt arrow.DataType + shape Shape +}{ + { + rowMajorData: []int8{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []int8{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, + colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, + dt: 
arrow.PrimitiveTypes.Int8, + shape: Shape{2, 5}, + }, + { + rowMajorData: []int16{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []int16{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, + colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, + dt: arrow.PrimitiveTypes.Int16, + shape: Shape{2, 5}, + }, + { + rowMajorData: []int32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []int32{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, + colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, + dt: arrow.PrimitiveTypes.Int32, + shape: Shape{2, 5}, + }, + { + rowMajorData: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []int64{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, + colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, + dt: arrow.PrimitiveTypes.Int64, + shape: Shape{2, 5}, + }, + { + rowMajorData: []uint8{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []uint8{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, + colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, + dt: arrow.PrimitiveTypes.Uint8, + shape: Shape{2, 5}, + }, + { + rowMajorData: []uint16{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []uint16{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, + colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, + dt: arrow.PrimitiveTypes.Uint16, + shape: Shape{2, 5}, + }, + { + rowMajorData: []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []uint32{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, + colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, + dt: arrow.PrimitiveTypes.Uint32, + shape: Shape{2, 5}, + }, + { + rowMajorData: []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []uint64{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, + colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, + dt: arrow.PrimitiveTypes.Uint64, + shape: Shape{2, 5}, + }, + { + rowMajorData: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []float32{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, + colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, + dt: arrow.PrimitiveTypes.Float32, + shape: Shape{2, 5}, + }, + { + rowMajorData: []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []float64{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, + colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, + dt: arrow.PrimitiveTypes.Float64, + shape: Shape{2, 5}, + }, +} + +func TestFromArrowTensor(t *testing.T) { + assert := assert.New(t) + var rowMajorT *Dense + var colMajorT *Dense + pool := memory.NewGoAllocator() + + for i, taat := range toArrowTensorTests { + var rowMajorArr arrowArray.Interface + var colMajorArr 
arrowArray.Interface + var rowMajor arrowTensor.Interface + var colMajor arrowTensor.Interface + + switch taat.dt { + case arrow.PrimitiveTypes.Int8: + b := arrowArray.NewInt8Builder(pool) + defer b.Release() + b.AppendValues( + []int8{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + taat.rowMajorValid, + ) + rowMajorArr = b.NewArray() + defer rowMajorArr.Release() + + b.AppendValues( + []int8{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + taat.rowMajorValid, + ) + colMajorArr = b.NewArray() + defer colMajorArr.Release() + + rowMajor = arrowTensor.NewInt8(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"}) + defer rowMajor.Release() + colMajor = arrowTensor.NewInt8(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.Int8SizeBytes), int64(arrow.Int8SizeBytes * 2)}, []string{"x", "y"}) + defer colMajor.Release() + case arrow.PrimitiveTypes.Int16: + b := arrowArray.NewInt16Builder(pool) + defer b.Release() + b.AppendValues( + []int16{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + taat.rowMajorValid, + ) + rowMajorArr = b.NewArray() + defer rowMajorArr.Release() + + b.AppendValues( + []int16{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + taat.rowMajorValid, + ) + colMajorArr = b.NewArray() + defer colMajorArr.Release() + + rowMajor = arrowTensor.NewInt16(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"}) + defer rowMajor.Release() + colMajor = arrowTensor.NewInt16(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.Int16SizeBytes), int64(arrow.Int16SizeBytes * 2)}, []string{"x", "y"}) + defer colMajor.Release() + case arrow.PrimitiveTypes.Int32: + b := arrowArray.NewInt32Builder(pool) + defer b.Release() + b.AppendValues( + []int32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + taat.rowMajorValid, + ) + rowMajorArr = b.NewArray() + defer rowMajorArr.Release() + + b.AppendValues( + []int32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + taat.rowMajorValid, + ) + colMajorArr = b.NewArray() + defer colMajorArr.Release() + + rowMajor = arrowTensor.NewInt32(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"}) + defer rowMajor.Release() + colMajor = arrowTensor.NewInt32(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.Int32SizeBytes), int64(arrow.Int32SizeBytes * 2)}, []string{"x", "y"}) + defer colMajor.Release() + case arrow.PrimitiveTypes.Int64: + b := arrowArray.NewInt64Builder(pool) + defer b.Release() + b.AppendValues( + []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + taat.rowMajorValid, + ) + rowMajorArr = b.NewArray() + defer rowMajorArr.Release() + + b.AppendValues( + []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + taat.rowMajorValid, + ) + colMajorArr = b.NewArray() + defer colMajorArr.Release() + + rowMajor = arrowTensor.NewInt64(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"}) + defer rowMajor.Release() + colMajor = arrowTensor.NewInt64(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.Int64SizeBytes), int64(arrow.Int64SizeBytes * 2)}, []string{"x", "y"}) + defer colMajor.Release() + case arrow.PrimitiveTypes.Uint8: + b := arrowArray.NewUint8Builder(pool) + defer b.Release() + b.AppendValues( + []uint8{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + taat.rowMajorValid, + ) + rowMajorArr = b.NewArray() + defer rowMajorArr.Release() + + b.AppendValues( + []uint8{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + taat.rowMajorValid, + ) + colMajorArr = b.NewArray() + defer colMajorArr.Release() + + rowMajor = arrowTensor.NewUint8(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"}) + defer rowMajor.Release() + colMajor = arrowTensor.NewUint8(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.Uint8SizeBytes), int64(arrow.Uint8SizeBytes * 2)}, 
[]string{"x", "y"}) + defer colMajor.Release() + case arrow.PrimitiveTypes.Uint16: + b := arrowArray.NewUint16Builder(pool) + defer b.Release() + b.AppendValues( + []uint16{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + taat.rowMajorValid, + ) + rowMajorArr = b.NewArray() + defer rowMajorArr.Release() + + b.AppendValues( + []uint16{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + taat.rowMajorValid, + ) + colMajorArr = b.NewArray() + defer colMajorArr.Release() + + rowMajor = arrowTensor.NewUint16(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"}) + defer rowMajor.Release() + colMajor = arrowTensor.NewUint16(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.Uint16SizeBytes), int64(arrow.Uint16SizeBytes * 2)}, []string{"x", "y"}) + defer colMajor.Release() + case arrow.PrimitiveTypes.Uint32: + b := arrowArray.NewUint32Builder(pool) + defer b.Release() + b.AppendValues( + []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + taat.rowMajorValid, + ) + rowMajorArr = b.NewArray() + defer rowMajorArr.Release() + + b.AppendValues( + []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + taat.rowMajorValid, + ) + colMajorArr = b.NewArray() + defer colMajorArr.Release() + + rowMajor = arrowTensor.NewUint32(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"}) + defer rowMajor.Release() + colMajor = arrowTensor.NewUint32(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.Uint32SizeBytes), int64(arrow.Uint32SizeBytes * 2)}, []string{"x", "y"}) + defer colMajor.Release() + case arrow.PrimitiveTypes.Uint64: + b := arrowArray.NewUint64Builder(pool) + defer b.Release() + b.AppendValues( + []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + taat.rowMajorValid, + ) + rowMajorArr = b.NewArray() + defer rowMajorArr.Release() + + b.AppendValues( + []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + taat.rowMajorValid, + ) + colMajorArr = b.NewArray() + defer colMajorArr.Release() + + rowMajor = arrowTensor.NewUint64(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"}) + defer rowMajor.Release() + colMajor = arrowTensor.NewUint64(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.Uint64SizeBytes), int64(arrow.Uint64SizeBytes * 2)}, []string{"x", "y"}) + defer colMajor.Release() + case arrow.PrimitiveTypes.Float32: + b := arrowArray.NewFloat32Builder(pool) + defer b.Release() + b.AppendValues( + []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + taat.rowMajorValid, + ) + rowMajorArr = b.NewArray() + defer rowMajorArr.Release() + + b.AppendValues( + []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + taat.rowMajorValid, + ) + colMajorArr = b.NewArray() + defer colMajorArr.Release() + + rowMajor = arrowTensor.NewFloat32(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"}) + defer rowMajor.Release() + colMajor = arrowTensor.NewFloat32(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.Float32SizeBytes), int64(arrow.Float32SizeBytes * 2)}, []string{"x", "y"}) + defer colMajor.Release() + case arrow.PrimitiveTypes.Float64: + b := arrowArray.NewFloat64Builder(pool) + defer b.Release() + b.AppendValues( + []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + taat.rowMajorValid, + ) + rowMajorArr = b.NewArray() + defer rowMajorArr.Release() + + b.AppendValues( + []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + taat.rowMajorValid, + ) + colMajorArr = b.NewArray() + defer colMajorArr.Release() + + rowMajor = arrowTensor.NewFloat64(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"}) + defer rowMajor.Release() + colMajor = arrowTensor.NewFloat64(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.Float64SizeBytes), int64(arrow.Float64SizeBytes * 2)}, []string{"x", 
"y"}) + defer colMajor.Release() + default: + t.Errorf("DataType not supported in tests: %v", taat.dt) + } + + rowMajorT = FromArrowTensor(rowMajor) + colMajorT = FromArrowTensor(colMajor) + + assert.Equal(taat.rowMajorData, rowMajorT.Data(), "test %d: row major %v", i, taat.dt) + assert.Equal(len(taat.rowMajorValid), len(rowMajorT.Mask()), "test %d: row major %v mask length incorrect", i, taat.dt) + for i, invalid := range rowMajorT.Mask() { + assert.Equal(taat.rowMajorValid[i], !invalid, "test %d: row major %v mask value incorrect", i, taat.dt) + } + assert.True(colMajorT.Shape().Eq(taat.shape)) + + assert.Equal(taat.colMajorData, colMajorT.Data(), "test %d: column major %v", i, taat.dt) + assert.Equal(len(taat.colMajorValid), len(colMajorT.Mask()), "test %d: column major %v mask length incorrect", i, taat.dt) + for i, invalid := range colMajorT.Mask() { + assert.Equal(taat.colMajorValid[i], !invalid, "test %d: column major %v mask value incorrect", i, taat.dt) + } + assert.True(rowMajorT.Shape().Eq(taat.shape)) + } +} diff --git a/dense_format.go b/dense_format.go index ab4cfef..b7f8611 100644 --- a/dense_format.go +++ b/dense_format.go @@ -45,8 +45,8 @@ type fmtState struct { meta bool flat bool - ext bool - comp bool + ext bool // extended (i.e no elision) + comp bool // compact c rune // c is here mainly for struct packing reasons w, p int // width and precision @@ -121,10 +121,14 @@ func (f *fmtState) cleanFmt() string { // does the calculation for metadata func (f *fmtState) populate(t *Dense) { - if t.IsVector() { + switch { + case t.IsVector(): f.rows = 1 f.cols = t.Size() - } else { + case t.IsScalarEquiv(): + f.rows = 1 + f.cols = 1 + default: f.rows = t.Shape()[t.Dims()-2] f.cols = t.Shape()[t.Dims()-1] } @@ -153,6 +157,10 @@ func (f *fmtState) populate(t *Dense) { } func (f *fmtState) acceptableRune(d *Dense) { + if f.c == 'H' { + f.meta = true + return // accept H as header only + } switch d.t.Kind() { case reflect.Float64: switch f.c { @@ -248,6 +256,11 @@ func (f *fmtState) writeVElision() { // // Special care also needs be taken for the verb 's' - it prints a super compressed version of the tensor, only printing 4 cols and 4 rows. 
func (t *Dense) Format(s fmt.State, c rune) { + if c == 'i' { + fmt.Fprintf(s, "INFO:\n\tAP: %v\n\tOLD: %v\n\tTRANS %v\n\tENGINE: %T\n", t.AP, t.old, t.transposeWith, t.e) + return + } + f := newFmtState(s, c) if t.IsScalar() { o := f.originalFmt() @@ -273,6 +286,10 @@ func (t *Dense) Format(s fmt.State, c rune) { fmt.Fprintf(f, " %v %v\n", t.Shape(), t.Strides()) } + if f.c == 'H' { + return + } + if !t.IsNativelyAccessible() { fmt.Fprintf(f, "Inaccesible data") return @@ -348,14 +365,13 @@ func (t *Dense) Format(s fmt.State, c rune) { } // standard stuff - it := NewIterator(t.AP) + it := NewIterator(&t.AP) coord := it.Coord() firstRow := true firstVal := true var lastRow, lastCol int var expected int - for next, err := it.Next(); err == nil; next, err = it.Next() { if next < expected { continue @@ -377,6 +393,10 @@ func (t *Dense) Format(s fmt.State, c rune) { f.Write(rowVecStart) case t.IsVector(): f.Write(vecStart) + case t.IsScalarEquiv(): + for i := 0; i < t.Dims(); i++ { + f.Write(vecStart) + } default: f.Write(matFirstStart) } @@ -427,6 +447,11 @@ func (t *Dense) Format(s fmt.State, c rune) { case t.IsVector(): f.Write(vecEnd) return + case t.IsScalarEquiv(): + for i := 0; i < t.Dims(); i++ { + f.Write(vecEnd) + } + return case firstRow: f.Write(matFirstEnd) case eom: diff --git a/dense_format_test.go b/dense_format_test.go index b4b230c..50d4acb 100644 --- a/dense_format_test.go +++ b/dense_format_test.go @@ -21,6 +21,16 @@ func TestDense_Format(t *testing.T) { res = fmt.Sprintf("%3.3f", T) assert.Equal("3.140", res) + // Scalar-equiv (vector) + T = New(WithBacking([]float64{3.14}), WithShape(1)) + res = fmt.Sprintf("%3.3f", T) + assert.Equal("[3.140]", res) + + // Scalar-equiv (n-dimensional) + T = New(WithBacking([]float64{3.14}), WithShape(1, 1, 1, 1)) + res = fmt.Sprintf("%3.3f", T) + assert.Equal("[[[[3.140]]]]", res) + // short vector T = New(Of(Float64), WithShape(4)) res = fmt.Sprintf("%v", T) @@ -73,11 +83,13 @@ Matrix (2, 2) [2 1] // many cols, rows, compressed T = New(WithShape(16, 14), WithBacking(Range(Float64, 0, 16*14))) res = fmt.Sprintf("\n%s", T) + // this clunky string addition thing is because some editors like to trim whitespace. + // There should be two spaces after ` ⋮` . expected = ` ⎡ 0 1 ⋯ 12 13⎤ ⎢ 14 15 ⋯ 26 27⎥ - ⋮ -⎢196 197 ⋯ 208 209⎥ +` + ` ⋮ ` + ` +` + `⎢196 197 ⋯ 208 209⎥ ⎣210 211 ⋯ 222 223⎦ ` assert.Equal(expected, res, "expected %v. Got %v", expected, res) diff --git a/dense_generated.go b/dense_generated.go index 93ea20b..c9158fa 100644 --- a/dense_generated.go +++ b/dense_generated.go @@ -1,11 +1,15 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor -import "reflect" +import ( + "reflect" + + "gorgonia.org/dtype" +) + +// Code generated by genlib2. DO NOT EDIT. 
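// Illustrative sketch (not part of the patch): exercising the formatting verbs
// handled by (*Dense).Format above. '%v' prints the full (elided) tensor, while
// the newly added '%H' and '%i' print only the metadata header and internal
// debugging info respectively. Assumes this package is imported as
// "gorgonia.org/tensor".
package main

import (
	"fmt"

	"gorgonia.org/tensor"
)

func main() {
	T := tensor.New(tensor.WithShape(16, 14), tensor.WithBacking(tensor.Range(tensor.Float64, 0, 16*14)))

	fmt.Printf("%v\n", T) // full matrix, elided past 4 rows/cols of context
	fmt.Printf("%H\n", T) // metadata header only
	fmt.Printf("%i\n", T) // internal info: AP, old AP, transpose axes, engine

	// Scalar-equivalent tensors print with one pair of brackets per dimension.
	S := tensor.New(tensor.WithBacking([]float64{3.14}), tensor.WithShape(1, 1, 1))
	fmt.Printf("%3.3f\n", S) // [[[3.140]]]
}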
// Ones creates a *Dense with the provided shape and type -func Ones(dt Dtype, shape ...int) *Dense { +func Ones(dt dtype.Dtype, shape ...int) *Dense { d := recycledDense(dt, shape) switch d.t.Kind() { case reflect.Int: @@ -68,7 +72,7 @@ func Ones(dt Dtype, shape ...int) *Dense { // ⎢1 0 0 0⎥ // ⎢0 1 0 0⎥ // ⎣0 0 1 0⎦ -func I(dt Dtype, r, c, k int) *Dense { +func I(dt dtype.Dtype, r, c, k int) *Dense { ret := New(Of(dt), WithShape(r, c)) i := k if k < 0 { @@ -88,7 +92,7 @@ func I(dt Dtype, r, c, k int) *Dense { panic(err) } var nexts []int - iter := NewFlatIterator(s.AP) + iter := newFlatIterator(&s.AP) nexts, err = iter.Slice(rs{i, s.Size(), c + 1}) switch s.t.Kind() { diff --git a/dense_generated_test.go b/dense_generated_test.go index e87baa0..edd0850 100644 --- a/dense_generated_test.go +++ b/dense_generated_test.go @@ -1,15 +1,16 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( "testing" "github.com/stretchr/testify/assert" + "gorgonia.org/dtype" ) +// Code generated by genlib2. DO NOT EDIT. + var onesTests = []struct { - of Dtype + of dtype.Dtype shape Shape correct interface{} }{ @@ -56,7 +57,7 @@ func TestOnes(t *testing.T) { // yes, it's a pun on eye tests, stop asking and go see your optometrist var eyeTests = []struct { - E Dtype + E dtype.Dtype R, C, K int correct interface{} diff --git a/dense_getset_test.go b/dense_getset_test.go index 8ab8e44..899e855 100644 --- a/dense_getset_test.go +++ b/dense_getset_test.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -8,10 +6,13 @@ import ( "testing/quick" "github.com/stretchr/testify/assert" + "gorgonia.org/dtype" ) +// Code generated by genlib2. DO NOT EDIT. + var denseSetGetTests = []struct { - of Dtype + of dtype.Dtype data interface{} set interface{} @@ -48,7 +49,7 @@ func TestDense_setget(t *testing.T) { } var denseMemsetTests = []struct { - of Dtype + of dtype.Dtype data interface{} val interface{} shape Shape @@ -88,7 +89,7 @@ func TestDense_memset(t *testing.T) { } var denseZeroTests = []struct { - of Dtype + of dtype.Dtype data interface{} correct interface{} diff --git a/dense_io.go b/dense_io.go index 84896eb..374daf0 100644 --- a/dense_io.go +++ b/dense_io.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -14,26 +12,143 @@ import ( "strconv" "strings" + flatbuffers "github.com/google/flatbuffers/go" "github.com/pkg/errors" + "gorgonia.org/dtype" + "gorgonia.org/tensor/internal/serialization/fb" + "gorgonia.org/tensor/internal/serialization/pb" ) +// Code generated by genlib2. DO NOT EDIT. 
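// Illustrative sketch (not part of the patch): the constructors above now take a
// dtype.Dtype rather than the package-local Dtype. This assumes the package-level
// type constants (e.g. tensor.Float64) remain usable wherever a dtype.Dtype is
// expected, and that this package is imported as "gorgonia.org/tensor".
package main

import (
	"fmt"

	"gorgonia.org/tensor"
)

func main() {
	ones := tensor.Ones(tensor.Float64, 2, 3)  // 2x3 matrix filled with 1s
	eye := tensor.I(tensor.Float64, 4, 4, -1)  // 4x4 identity shifted to the first sub-diagonal

	fmt.Printf("%v\n%v\n", ones, eye)
}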
+ +/* GOB SERIALIZATION */ + +// GobEncode implements gob.GobEncoder +func (t *Dense) GobEncode() (p []byte, err error) { + var buf bytes.Buffer + encoder := gob.NewEncoder(&buf) + + if err = encoder.Encode(t.Shape()); err != nil { + return + } + + if err = encoder.Encode(t.Strides()); err != nil { + return + } + + if err = encoder.Encode(t.AP.o); err != nil { + return + } + + if err = encoder.Encode(t.AP.Δ); err != nil { + return + } + + if err = encoder.Encode(t.mask); err != nil { + return + } + + data := t.Data() + if err = encoder.Encode(&data); err != nil { + return + } + + return buf.Bytes(), err +} + +// GobDecode implements gob.GobDecoder +func (t *Dense) GobDecode(p []byte) (err error) { + buf := bytes.NewBuffer(p) + decoder := gob.NewDecoder(buf) + + var shape Shape + if err = decoder.Decode(&shape); err != nil { + return + } + + var strides []int + if err = decoder.Decode(&strides); err != nil { + return + } + + var o DataOrder + var tr Triangle + if err = decoder.Decode(&o); err == nil { + if err = decoder.Decode(&tr); err != nil { + return + } + } + + t.AP.Init(shape, strides) + t.AP.o = o + t.AP.Δ = tr + + var mask []bool + if err = decoder.Decode(&mask); err != nil { + return + } + + var data interface{} + if err = decoder.Decode(&data); err != nil { + return + } + + t.fromSlice(data) + t.addMask(mask) + t.fix() + if t.e == nil { + t.e = StdEng{} + } + return t.sanity() +} + +/* NPY SERIALIZATION */ + +var npyDescRE = regexp.MustCompile(`'descr':\s*'([^']*)'`) +var rowOrderRE = regexp.MustCompile(`'fortran_order':\s*(False|True)`) +var shapeRE = regexp.MustCompile(`'shape':\s*\(([^\(]*)\)`) + type binaryWriter struct { io.Writer - error + err error seq int } -func (w binaryWriter) w(x interface{}) { - if w.error != nil { +func (w *binaryWriter) w(x interface{}) { + if w.err != nil { return } - binary.Write(w, binary.LittleEndian, x) + w.err = binary.Write(w, binary.LittleEndian, x) w.seq++ } -func (w binaryWriter) Error() string { - return fmt.Sprintf("Error at sequence %d : %v", w.seq, w.error.Error()) +func (w *binaryWriter) Err() error { + if w.err == nil { + return nil + } + return errors.Wrapf(w.err, "Sequence %d", w.seq) +} + +type binaryReader struct { + io.Reader + err error + seq int +} + +func (r *binaryReader) Read(data interface{}) { + if r.err != nil { + return + } + r.err = binary.Read(r.Reader, binary.LittleEndian, data) + r.seq++ +} + +func (r *binaryReader) Err() error { + if r.err == nil { + return nil + } + return errors.Wrapf(r.err, "Sequence %d", r.seq) } // WriteNpy writes the *Tensor as a numpy compatible serialized file. @@ -49,12 +164,19 @@ func (w binaryWriter) Error() string { // If tensor is masked, invalid values are replaced by the default fill value. func (t *Dense) WriteNpy(w io.Writer) (err error) { var npdt string - if npdt, err = t.t.numpyDtype(); err != nil { + if npdt, err = t.t.NumpyDtype(); err != nil { return } - header := "{'descr': '<%v', 'fortran_order': False, 'shape': %v}" - header = fmt.Sprintf(header, npdt, t.Shape()) + var header string + if t.Dims() == 1 { + // when t is a 1D vector, numpy expects "(N,)" instead of "(N)" which t.Shape() returns. 
+ header = "{'descr': '<%v', 'fortran_order': False, 'shape': (%d,)}" + header = fmt.Sprintf(header, npdt, t.Shape()[0]) + } else { + header = "{'descr': '<%v', 'fortran_order': False, 'shape': %v}" + header = fmt.Sprintf(header, npdt, t.Shape()) + } padding := 16 - ((10 + len(header)) % 16) if padding > 0 { header = header + strings.Repeat(" ", padding) @@ -64,8 +186,8 @@ func (t *Dense) WriteNpy(w io.Writer) (err error) { bw.w(byte(1)) // major version bw.w(byte(0)) // minor version bw.w(uint16(len(header))) // 4 bytes to denote header length - if bw.error != nil { - return bw + if err = bw.Err(); err != nil { + return err } bw.Write([]byte(header)) @@ -86,176 +208,57 @@ func (t *Dense) WriteNpy(w io.Writer) (err error) { } } - if bw.error != nil { - return bw - } - return nil -} - -// WriteCSV writes the *Dense to a CSV. It accepts an optional string formatting ("%v", "%f", etc...), which controls what is written to the CSV. -// If tensor is masked, invalid values are replaced by the default fill value. -func (t *Dense) WriteCSV(w io.Writer, formats ...string) (err error) { - // checks: - if !t.IsMatrix() { - // error - err = errors.Errorf("Cannot write *Dense to CSV. Expected number of dimensions: <=2, T has got %d dimensions (Shape: %v)", t.Dims(), t.Shape()) - return - } - format := "%v" - if len(formats) > 0 { - format = formats[0] - } - - cw := csv.NewWriter(w) - it := IteratorFromDense(t) - coord := it.Coord() - - // rows := t.Shape()[0] - cols := t.Shape()[1] - record := make([]string, 0, cols) - var i, k, lastCol int - isMasked := t.IsMasked() - fillval := t.FillValue() - fillstr := fmt.Sprintf(format, fillval) - for i, err = it.Next(); err == nil; i, err = it.Next() { - record = append(record, fmt.Sprintf(format, t.Get(i))) - if isMasked { - if t.mask[i] { - record[k] = fillstr - } - k++ - } - if lastCol == cols-1 { - if err = cw.Write(record); err != nil { - // TODO: wrap errors - return - } - cw.Flush() - record = record[:0] - } - - // cleanup - switch { - case t.IsRowVec(): - // lastRow = coord[len(coord)-2] - lastCol = coord[len(coord)-1] - case t.IsColVec(): - // lastRow = coord[len(coord)-1] - lastCol = coord[len(coord)-2] - case t.IsVector(): - lastCol = coord[len(coord)-1] - default: - // lastRow = coord[len(coord)-2] - lastCol = coord[len(coord)-1] - } - } - return nil -} - -// GobEncode implements gob.GobEncoder -func (t *Dense) GobEncode() (p []byte, err error) { - var buf bytes.Buffer - encoder := gob.NewEncoder(&buf) - - if err = encoder.Encode(t.Shape()); err != nil { - return - } - - if err = encoder.Encode(t.Strides()); err != nil { - return - } - - if err = encoder.Encode(t.AP.o); err != nil { - return - } - - if err = encoder.Encode(t.AP.Δ); err != nil { - return - } - - if err = encoder.Encode(t.mask); err != nil { - return - } - - data := t.Data() - if err = encoder.Encode(&data); err != nil { - return - } - - return buf.Bytes(), err + return bw.Err() } // ReadNpy reads NumPy formatted files into a *Dense func (t *Dense) ReadNpy(r io.Reader) (err error) { + br := binaryReader{Reader: r} var magic [6]byte - if _, err = r.Read(magic[:]); err != nil { - return - } - if string(magic[:]) != "\x93NUMPY" { - err = errors.Errorf("Not a numpy file. Got %q as the magic number instead", string(magic[:])) - return + if br.Read(magic[:]); string(magic[:]) != "\x93NUMPY" { + return errors.Errorf("Not a numpy file. 
Got %q as the magic number instead", string(magic[:])) } - var version byte - if err = binary.Read(r, binary.LittleEndian, &version); err != nil { - return - } - if version != 1 { - err = errors.New("Only verion 1.0 of numpy's serialization format is currently supported (65535 bytes ought to be enough for a header)") - return + var version, minor byte + if br.Read(&version); version != 1 { + return errors.New("Only verion 1.0 of numpy's serialization format is currently supported (65535 bytes ought to be enough for a header)") } - var minor byte - if err = binary.Read(r, binary.LittleEndian, &minor); err != nil { - return - } - if minor != 0 { - err = errors.New("Only verion 1.0 of numpy's serialization format is currently supported (65535 bytes ought to be enough for a header)") - return + if br.Read(&minor); minor != 0 { + return errors.New("Only verion 1.0 of numpy's serialization format is currently supported (65535 bytes ought to be enough for a header)") } var headerLen uint16 - if err = binary.Read(r, binary.LittleEndian, &headerLen); err != nil { - return - } - + br.Read(&headerLen) header := make([]byte, int(headerLen)) - if _, err = r.Read(header); err != nil { + br.Read(header) + if err = br.Err(); err != nil { return } - desc := regexp.MustCompile(`'descr':\s*'([^']*)'`) - match := desc.FindSubmatch(header) - if match == nil { - err = errors.New("No dtype information in npy file") - return + // extract stuff from header + var match [][]byte + if match = npyDescRE.FindSubmatch(header); match == nil { + return errors.New("No dtype information in npy file") } // TODO: check for endianness. For now we assume everything is little endian - var dt Dtype - if dt, err = fromNumpyDtype(string(match[1][1:])); err != nil { + if t.t, err = dtype.FromNumpyDtype(string(match[1][1:])); err != nil { return } - t.t = dt - rowOrder := regexp.MustCompile(`'fortran_order':\s*(False|True)`) - match = rowOrder.FindSubmatch(header) - if match == nil { - err = errors.New("No Row Order information found in the numpy file") - return + if match = rowOrderRE.FindSubmatch(header); match == nil { + return errors.New("No Row Order information found in the numpy file") } if string(match[1]) != "False" { - err = errors.New("Cannot yet read from Fortran Ordered Numpy files") - return + return errors.New("Cannot yet read from Fortran Ordered Numpy files") } - shpRe := regexp.MustCompile(`'shape':\s*\(([^\(]*)\)`) - match = shpRe.FindSubmatch(header) - if match == nil { - err = errors.New("No shape information found in npy file") - return + if match = shapeRE.FindSubmatch(header); match == nil { + return errors.New("No shape information found in npy file") } sizesStr := strings.Split(string(match[1]), ",") + var shape Shape for _, s := range sizesStr { s = strings.Trim(s, " ") @@ -273,163 +276,166 @@ func (t *Dense) ReadNpy(r io.Reader) (err error) { if t.e == nil { t.e = StdEng{} } - t.makeArray(size) switch t.t.Kind() { case reflect.Int: data := t.Ints() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Int8: data := t.Int8s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Int16: data := t.Int16s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Int32: data := t.Int32s() for i := 0; i < size; i++ { - if err = 
binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Int64: data := t.Int64s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Uint: data := t.Uints() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Uint8: data := t.Uint8s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Uint16: data := t.Uint16s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Uint32: data := t.Uint32s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Uint64: data := t.Uint64s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Float32: data := t.Float32s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Float64: data := t.Float64s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Complex64: data := t.Complex64s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Complex128: data := t.Complex128s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } } - t.AP = BorrowAP(len(shape)) + if err = br.Err(); err != nil { + return err + } + + t.AP.zeroWithDims(len(shape)) t.setShape(shape...) t.fix() return t.sanity() } -// GobDecode implements gob.GobDecoder -func (t *Dense) GobDecode(p []byte) (err error) { - buf := bytes.NewBuffer(p) - decoder := gob.NewDecoder(buf) - - var shape Shape - if err = decoder.Decode(&shape); err != nil { - return - } +/* CSV SERIALIZATION */ - var strides []int - if err = decoder.Decode(&strides); err != nil { - return +// WriteCSV writes the *Dense to a CSV. It accepts an optional string formatting ("%v", "%f", etc...), which controls what is written to the CSV. +// If tensor is masked, invalid values are replaced by the default fill value. +func (t *Dense) WriteCSV(w io.Writer, formats ...string) (err error) { + // checks: + if !t.IsMatrix() { + // error + err = errors.Errorf("Cannot write *Dense to CSV. 
Expected number of dimensions: <=2, T has got %d dimensions (Shape: %v)", t.Dims(), t.Shape()) + return } - - var o DataOrder - var tr Triangle - if err = decoder.Decode(&o); err == nil { - if err = decoder.Decode(&tr); err != nil { - return - } + format := "%v" + if len(formats) > 0 { + format = formats[0] } - t.AP = NewAP(shape, strides) - t.AP.o = o - t.AP.Δ = tr + cw := csv.NewWriter(w) + it := IteratorFromDense(t) + coord := it.Coord() - var mask []bool - if err = decoder.Decode(&mask); err != nil { - return - } + // rows := t.Shape()[0] + cols := t.Shape()[1] + record := make([]string, 0, cols) + var i, k, lastCol int + isMasked := t.IsMasked() + fillval := t.FillValue() + fillstr := fmt.Sprintf(format, fillval) + for i, err = it.Next(); err == nil; i, err = it.Next() { + record = append(record, fmt.Sprintf(format, t.Get(i))) + if isMasked { + if t.mask[i] { + record[k] = fillstr + } + k++ + } + if lastCol == cols-1 { + if err = cw.Write(record); err != nil { + // TODO: wrap errors + return + } + cw.Flush() + record = record[:0] + } - var data interface{} - if err = decoder.Decode(&data); err != nil { - return + // cleanup + switch { + case t.IsRowVec(): + // lastRow = coord[len(coord)-2] + lastCol = coord[len(coord)-1] + case t.IsColVec(): + // lastRow = coord[len(coord)-1] + lastCol = coord[len(coord)-2] + case t.IsVector(): + lastCol = coord[len(coord)-1] + default: + // lastRow = coord[len(coord)-2] + lastCol = coord[len(coord)-1] + } } - t.fromSlice(data) - t.addMask(mask) - t.fix() - return t.sanity() + return nil } -// convFromStrs conversts a []string to a slice of the Dtype provided -func convFromStrs(to Dtype, record []string) (interface{}, error) { +// convFromStrs converts a []string to a slice of the Dtype provided. It takes a provided backing slice. +// If into is nil, then a backing slice will be created. +func convFromStrs(to dtype.Dtype, record []string, into interface{}) (interface{}, error) { var err error switch to.Kind() { case reflect.Int: retVal := make([]int, len(record)) + var backing []int + if into == nil { + backing = make([]int, 0, len(record)) + } else { + backing = into.([]int) + } + for i, v := range record { var i64 int64 if i64, err = strconv.ParseInt(v, 10, 0); err != nil { @@ -437,9 +443,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = int(i64) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Int8: retVal := make([]int8, len(record)) + var backing []int8 + if into == nil { + backing = make([]int8, 0, len(record)) + } else { + backing = into.([]int8) + } + for i, v := range record { var i64 int64 if i64, err = strconv.ParseInt(v, 10, 8); err != nil { @@ -447,9 +461,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = int8(i64) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Int16: retVal := make([]int16, len(record)) + var backing []int16 + if into == nil { + backing = make([]int16, 0, len(record)) + } else { + backing = into.([]int16) + } + for i, v := range record { var i64 int64 if i64, err = strconv.ParseInt(v, 10, 16); err != nil { @@ -457,9 +479,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = int16(i64) } - return retVal, nil + backing = append(backing, retVal...) 
+ return backing, nil case reflect.Int32: retVal := make([]int32, len(record)) + var backing []int32 + if into == nil { + backing = make([]int32, 0, len(record)) + } else { + backing = into.([]int32) + } + for i, v := range record { var i64 int64 if i64, err = strconv.ParseInt(v, 10, 32); err != nil { @@ -467,9 +497,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = int32(i64) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Int64: retVal := make([]int64, len(record)) + var backing []int64 + if into == nil { + backing = make([]int64, 0, len(record)) + } else { + backing = into.([]int64) + } + for i, v := range record { var i64 int64 if i64, err = strconv.ParseInt(v, 10, 64); err != nil { @@ -477,9 +515,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = int64(i64) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Uint: retVal := make([]uint, len(record)) + var backing []uint + if into == nil { + backing = make([]uint, 0, len(record)) + } else { + backing = into.([]uint) + } + for i, v := range record { var u uint64 if u, err = strconv.ParseUint(v, 10, 0); err != nil { @@ -487,9 +533,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = uint(u) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Uint8: retVal := make([]uint8, len(record)) + var backing []uint8 + if into == nil { + backing = make([]uint8, 0, len(record)) + } else { + backing = into.([]uint8) + } + for i, v := range record { var u uint64 if u, err = strconv.ParseUint(v, 10, 8); err != nil { @@ -497,9 +551,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = uint8(u) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Uint16: retVal := make([]uint16, len(record)) + var backing []uint16 + if into == nil { + backing = make([]uint16, 0, len(record)) + } else { + backing = into.([]uint16) + } + for i, v := range record { var u uint64 if u, err = strconv.ParseUint(v, 10, 16); err != nil { @@ -507,9 +569,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = uint16(u) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Uint32: retVal := make([]uint32, len(record)) + var backing []uint32 + if into == nil { + backing = make([]uint32, 0, len(record)) + } else { + backing = into.([]uint32) + } + for i, v := range record { var u uint64 if u, err = strconv.ParseUint(v, 10, 32); err != nil { @@ -517,9 +587,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = uint32(u) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Uint64: retVal := make([]uint64, len(record)) + var backing []uint64 + if into == nil { + backing = make([]uint64, 0, len(record)) + } else { + backing = into.([]uint64) + } + for i, v := range record { var u uint64 if u, err = strconv.ParseUint(v, 10, 64); err != nil { @@ -527,9 +605,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = uint64(u) } - return retVal, nil + backing = append(backing, retVal...) 
+ return backing, nil case reflect.Float32: retVal := make([]float32, len(record)) + var backing []float32 + if into == nil { + backing = make([]float32, 0, len(record)) + } else { + backing = into.([]float32) + } + for i, v := range record { var f float64 if f, err = strconv.ParseFloat(v, 32); err != nil { @@ -537,15 +623,33 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = float32(f) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Float64: retVal := make([]float64, len(record)) + var backing []float64 + if into == nil { + backing = make([]float64, 0, len(record)) + } else { + backing = into.([]float64) + } + for i, v := range record { if retVal[i], err = strconv.ParseFloat(v, 64); err != nil { return nil, err } } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil + case reflect.String: + var backing []string + if into == nil { + backing = make([]string, 0, len(record)) + } else { + backing = into.([]string) + } + backing = append(backing, record...) + return backing, nil default: return nil, errors.Errorf(methodNYI, "convFromStrs", to) } @@ -564,307 +668,221 @@ func (t *Dense) ReadCSV(r io.Reader, opts ...FuncOpt) (err error) { cr := csv.NewReader(r) var record []string - var row interface{} var rows, cols int - - switch as.Kind() { - case reflect.Int: - var backing []int - for { - record, err = cr.Read() - if err == io.EOF { - break - } - - if err != nil { - return - } - - if row, err = convFromStrs(Int, record); err != nil { - return - } - backing = append(backing, row.([]int)...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Int8: - var backing []int8 - for { - record, err = cr.Read() - if err == io.EOF { - break - } - - if err != nil { - return - } - - if row, err = convFromStrs(Int8, record); err != nil { - return - } - backing = append(backing, row.([]int8)...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Int16: - var backing []int16 - for { - record, err = cr.Read() - if err == io.EOF { - break - } - - if err != nil { - return - } - - if row, err = convFromStrs(Int16, record); err != nil { - return - } - backing = append(backing, row.([]int16)...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Int32: - var backing []int32 - for { - record, err = cr.Read() - if err == io.EOF { - break - } - - if err != nil { - return - } - - if row, err = convFromStrs(Int32, record); err != nil { - return - } - backing = append(backing, row.([]int32)...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Int64: - var backing []int64 - for { - record, err = cr.Read() - if err == io.EOF { - break - } - - if err != nil { - return - } - - if row, err = convFromStrs(Int64, record); err != nil { - return - } - backing = append(backing, row.([]int64)...) 
- cols = len(record) - rows++ + var backing interface{} + for { + record, err = cr.Read() + if err == io.EOF { + break + } else if err != nil { + return } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Uint: - var backing []uint - for { - record, err = cr.Read() - if err == io.EOF { - break - } - - if err != nil { - return - } - - if row, err = convFromStrs(Uint, record); err != nil { - return - } - backing = append(backing, row.([]uint)...) - cols = len(record) - rows++ + if backing, err = convFromStrs(as, record, backing); err != nil { + return } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Uint8: - var backing []uint8 - for { - record, err = cr.Read() - if err == io.EOF { - break - } + cols = len(record) + rows++ + } + t.fromSlice(backing) + t.AP.zero() + t.AP.SetShape(rows, cols) + return nil + return errors.Errorf("not yet handled") +} - if err != nil { - return - } +/* FB SERIALIZATION */ - if row, err = convFromStrs(Uint8, record); err != nil { - return - } - backing = append(backing, row.([]uint8)...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Uint16: - var backing []uint16 - for { - record, err = cr.Read() - if err == io.EOF { - break - } +// FBEncode encodes to a byte slice using flatbuffers. +// +// Only natively accessible data can be encided +func (t *Dense) FBEncode() ([]byte, error) { + builder := flatbuffers.NewBuilder(1024) + + fb.DenseStartShapeVector(builder, len(t.shape)) + for i := len(t.shape) - 1; i >= 0; i-- { + builder.PrependInt32(int32(t.shape[i])) + } + shape := builder.EndVector(len(t.shape)) + + fb.DenseStartStridesVector(builder, len(t.strides)) + for i := len(t.strides) - 1; i >= 0; i-- { + builder.PrependInt32(int32(t.strides[i])) + } + strides := builder.EndVector(len(t.strides)) + + var o uint32 + switch { + case t.o.IsRowMajor() && t.o.IsContiguous(): + o = 0 + case t.o.IsRowMajor() && !t.o.IsContiguous(): + o = 1 + case t.o.IsColMajor() && t.o.IsContiguous(): + o = 2 + case t.o.IsColMajor() && !t.o.IsContiguous(): + o = 3 + } + + var triangle int32 + switch t.Δ { + case NotTriangle: + triangle = fb.TriangleNOT_TRIANGLE + case Upper: + triangle = fb.TriangleUPPER + case Lower: + triangle = fb.TriangleLOWER + case Symmetric: + triangle = fb.TriangleSYMMETRIC + } + + dt := builder.CreateString(t.Dtype().String()) + data := t.byteSlice() + + fb.DenseStartDataVector(builder, len(data)) + for i := len(data) - 1; i >= 0; i-- { + builder.PrependUint8(data[i]) + } + databyte := builder.EndVector(len(data)) + + fb.DenseStart(builder) + fb.DenseAddShape(builder, shape) + fb.DenseAddStrides(builder, strides) + fb.DenseAddO(builder, o) + fb.DenseAddT(builder, triangle) + fb.DenseAddType(builder, dt) + fb.DenseAddData(builder, databyte) + serialized := fb.DenseEnd(builder) + builder.Finish(serialized) + + return builder.FinishedBytes(), nil +} - if err != nil { - return - } +// FBDecode decodes a byteslice from a flatbuffer table into a *Dense +func (t *Dense) FBDecode(buf []byte) error { + serialized := fb.GetRootAsDense(buf, 0) - if row, err = convFromStrs(Uint16, record); err != nil { - return - } - backing = append(backing, row.([]uint16)...) 
- cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Uint32: - var backing []uint32 - for { - record, err = cr.Read() - if err == io.EOF { - break - } + o := serialized.O() + switch o { + case 0: + t.o = 0 + case 1: + t.o = MakeDataOrder(NonContiguous) + case 2: + t.o = MakeDataOrder(ColMajor) + case 3: + t.o = MakeDataOrder(ColMajor, NonContiguous) + } - if err != nil { - return - } + tri := serialized.T() + switch tri { + case fb.TriangleNOT_TRIANGLE: + t.Δ = NotTriangle + case fb.TriangleUPPER: + t.Δ = Upper + case fb.TriangleLOWER: + t.Δ = Lower + case fb.TriangleSYMMETRIC: + t.Δ = Symmetric + } - if row, err = convFromStrs(Uint32, record); err != nil { - return - } - backing = append(backing, row.([]uint32)...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Uint64: - var backing []uint64 - for { - record, err = cr.Read() - if err == io.EOF { - break - } + t.shape = Shape(BorrowInts(serialized.ShapeLength())) + for i := 0; i < serialized.ShapeLength(); i++ { + t.shape[i] = int(int32(serialized.Shape(i))) + } - if err != nil { - return - } + t.strides = BorrowInts(serialized.StridesLength()) + for i := 0; i < serialized.ShapeLength(); i++ { + t.strides[i] = int(serialized.Strides(i)) + } + typ := string(serialized.Type()) + dt, err := dtype.FindByName(typ) + if err != nil { + return errors.Wrap(err, "Failed to decode FlatBuffers") + } + t.t = dt - if row, err = convFromStrs(Uint64, record); err != nil { - return - } - backing = append(backing, row.([]uint64)...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Float32: - var backing []float32 - for { - record, err = cr.Read() - if err == io.EOF { - break - } + if t.e == nil { + t.e = StdEng{} + } + t.makeArray(t.shape.TotalSize()) - if err != nil { - return - } + // allocated data. Now time to actually copy over the data + db := t.byteSlice() + copy(db, serialized.DataBytes()) + t.fix() + return t.sanity() +} - if row, err = convFromStrs(Float32, record); err != nil { - return - } - backing = append(backing, row.([]float32)...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Float64: - var backing []float64 - for { - record, err = cr.Read() - if err == io.EOF { - break - } +/* PB SERIALIZATION */ + +// PBEncode encodes the Dense into a protobuf byte slice. 
+func (t *Dense) PBEncode() ([]byte, error) { + var toSerialize pb.Dense + toSerialize.Shape = make([]int32, len(t.shape)) + for i, v := range t.shape { + toSerialize.Shape[i] = int32(v) + } + toSerialize.Strides = make([]int32, len(t.strides)) + for i, v := range t.strides { + toSerialize.Strides[i] = int32(v) + } + + switch { + case t.o.IsRowMajor() && t.o.IsContiguous(): + toSerialize.O = pb.RowMajorContiguous + case t.o.IsRowMajor() && !t.o.IsContiguous(): + toSerialize.O = pb.RowMajorNonContiguous + case t.o.IsColMajor() && t.o.IsContiguous(): + toSerialize.O = pb.ColMajorContiguous + case t.o.IsColMajor() && !t.o.IsContiguous(): + toSerialize.O = pb.ColMajorNonContiguous + } + toSerialize.T = pb.Triangle(t.Δ) + toSerialize.Type = t.t.String() + data := t.byteSlice() + toSerialize.Data = make([]byte, len(data)) + copy(toSerialize.Data, data) + return toSerialize.Marshal() +} - if err != nil { - return - } +// PBDecode unmarshalls a protobuf byteslice into a *Dense. +func (t *Dense) PBDecode(buf []byte) error { + var toSerialize pb.Dense + if err := toSerialize.Unmarshal(buf); err != nil { + return err + } + t.shape = make(Shape, len(toSerialize.Shape)) + for i, v := range toSerialize.Shape { + t.shape[i] = int(v) + } + t.strides = make([]int, len(toSerialize.Strides)) + for i, v := range toSerialize.Strides { + t.strides[i] = int(v) + } - if row, err = convFromStrs(Float64, record); err != nil { - return - } - backing = append(backing, row.([]float64)...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.String: - var backing []string - for { - record, err = cr.Read() - if err == io.EOF { - break - } + switch toSerialize.O { + case pb.RowMajorContiguous: + case pb.RowMajorNonContiguous: + t.o = MakeDataOrder(NonContiguous) + case pb.ColMajorContiguous: + t.o = MakeDataOrder(ColMajor) + case pb.ColMajorNonContiguous: + t.o = MakeDataOrder(ColMajor, NonContiguous) + } + t.Δ = Triangle(toSerialize.T) + typ := string(toSerialize.Type) + dt, err := dtype.FindByName(typ) + if err != nil { + return errors.Wrap(err, "Failed to decode ProtoBuf") + } + t.t = dt - if err != nil { - return - } - backing = append(backing, record...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - default: - return errors.Errorf("%v not yet handled", as) + if t.e == nil { + t.e = StdEng{} } - return errors.Errorf("not yet handled") + t.makeArray(t.shape.TotalSize()) + + // allocated data. Now time to actually copy over the data + db := t.byteSlice() + copy(db, toSerialize.Data) + return t.sanity() } diff --git a/dense_io_test.go b/dense_io_test.go index 01de3f0..d2a6548 100644 --- a/dense_io_test.go +++ b/dense_io_test.go @@ -3,31 +3,63 @@ package tensor import ( "bytes" "encoding/gob" + "io/ioutil" "os" "os/exec" + "regexp" "testing" "github.com/stretchr/testify/assert" ) func TestSaveLoadNumpy(t *testing.T) { - if os.Getenv("TRAVISTEST") == "true" { - t.Skip("skipping test; This is being run on TravisCI") + if os.Getenv("CI_NO_PYTHON") == "true" { + t.Skip("skipping test; This is being run on a CI tool that does not have Python") } assert := assert.New(t) T := New(WithShape(2, 2), WithBacking([]float64{1, 5, 10, -1})) + // also checks the 1D Vector. 
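PBEncode and PBDecode are the protobuf twins of the FlatBuffers pair: same metadata, marshalled through the generated pb.Dense message instead. From the caller's side the usage is identical; a short sketch:

```go
package main

import (
	"fmt"
	"log"

	"gorgonia.org/tensor"
)

func main() {
	T := tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float32{1, 2, 3, 4}))

	buf, err := T.PBEncode()
	if err != nil {
		log.Fatal(err)
	}

	T2 := new(tensor.Dense)
	if err := T2.PBDecode(buf); err != nil {
		log.Fatal(err)
	}
	fmt.Println(T2.Shape(), T2.Data()) // (2, 2) [1 2 3 4]
}
```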
+ T1D := New(WithShape(4), WithBacking([]float64{1, 5, 10, -1})) + f, _ := os.OpenFile("test.npy", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0644) + f1D, _ := os.OpenFile("test1D.npy", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0644) + T.WriteNpy(f) f.Close() - script := "import numpy as np\nx = np.load('test.npy')\nprint(x)" + T1D.WriteNpy(f1D) + f1D.Close() + + defer func() { + // cleanup + err := os.Remove("test.npy") + if err != nil { + t.Error(err) + } + + err = os.Remove("test1D.npy") + if err != nil { + t.Error(err) + } + }() - cmd := exec.Command("python2") + script := "import numpy as np\nx = np.load('test.npy')\nprint(x)\nx = np.load('test1D.npy')\nprint(x)" + // Configurable python command, in order to be able to use python or python3 + pythonCommand := os.Getenv("PYTHON_COMMAND") + if pythonCommand == "" { + pythonCommand = "python" + } + + cmd := exec.Command(pythonCommand) stdin, err := cmd.StdinPipe() if err != nil { t.Error(err) } + stderr, err := cmd.StderrPipe() + if err != nil { + t.Error(err) + } go func() { defer stdin.Close() @@ -39,22 +71,22 @@ func TestSaveLoadNumpy(t *testing.T) { if err = cmd.Start(); err != nil { t.Error(err) + t.Logf("Do you have a python with numpy installed? You can change the python interpreter by setting the environment variable PYTHON_COMMAND. Current value: PYTHON_COMMAND=%s", pythonCommand) } - if err := cmd.Wait(); err != nil { - t.Error(err) + importError := `ImportError: No module named numpy` + slurpErr, _ := ioutil.ReadAll(stderr) + if ok, _ := regexp.Match(importError, slurpErr); ok { + t.Skipf("Skipping numpy test. It would appear that you do not have Numpy installed.") } - expected := "[[ 1. 5.]\n [ 10. -1.]]\n" - - if buf.String() != expected { - t.Errorf("Did not successfully read numpy file, \n%q\n%q", buf.String(), expected) + if err := cmd.Wait(); err != nil { + t.Errorf("%q", err.Error()) } - // cleanup - err = os.Remove("test.npy") - if err != nil { - t.Error(err) + expected := `\[\[\s*1\.\s*5\.\]\n \[\s*10\.\s*-1\.\]\]\n` + if ok, _ := regexp.Match(expected, buf.Bytes()); !ok { + t.Errorf("Did not successfully read numpy file, \n%q\n%q", buf.String(), expected) } // ok now to test if it can read @@ -68,6 +100,17 @@ func TestSaveLoadNumpy(t *testing.T) { assert.Equal(T.Strides(), T2.Strides()) assert.Equal(T.Data(), T2.Data()) + // ok now to test if it can read 1D + T1D2 := new(Dense) + buf = new(bytes.Buffer) + T1D.WriteNpy(buf) + if err = T1D2.ReadNpy(buf); err != nil { + t.Fatal(err) + } + assert.Equal(T1D.Shape(), T1D2.Shape()) + assert.Equal(T1D.Strides(), T1D2.Strides()) + assert.Equal(T1D.Data(), T1D2.Data()) + // try with masked array. masked elements should be filled with default value T.ResetMask(false) T.mask[0] = true @@ -82,6 +125,21 @@ func TestSaveLoadNumpy(t *testing.T) { data := T.Float64s() data[0] = T.FillValue().(float64) assert.Equal(data, T3.Data()) + + // try with 1D masked array. 
masked elements should be filled with default value + T1D.ResetMask(false) + T1D.mask[0] = true + T1D3 := new(Dense) + buf = new(bytes.Buffer) + T1D.WriteNpy(buf) + if err = T1D3.ReadNpy(buf); err != nil { + t.Fatal(err) + } + assert.Equal(T1D.Shape(), T1D3.Shape()) + assert.Equal(T1D.Strides(), T1D3.Strides()) + data = T1D.Float64s() + data[0] = T1D.FillValue().(float64) + assert.Equal(data, T1D3.Data()) } func TestSaveLoadCSV(t *testing.T) { @@ -204,5 +262,52 @@ func TestDense_GobEncodeDecode(t *testing.T) { assert.Equal(T.mask, T3.mask) } +} + +func TestDense_FBEncodeDecode(t *testing.T) { + assert := assert.New(t) + for _, gtd := range serializationTestData { + T := New(WithShape(2, 2), WithBacking(gtd)) + buf, err := T.FBEncode() + if err != nil { + t.Errorf("UNPOSSIBLE!: %v", err) + continue + } + + T2 := new(Dense) + if err = T2.FBDecode(buf); err != nil { + t.Errorf("Error while decoding %v: %v", gtd, err) + continue + } + + assert.Equal(T.Shape(), T2.Shape()) + assert.Equal(T.Strides(), T2.Strides()) + assert.Equal(T.Data(), T2.Data()) + + // TODO: MASKED ARRAY + } +} +func TestDense_PBEncodeDecode(t *testing.T) { + assert := assert.New(t) + for _, gtd := range serializationTestData { + T := New(WithShape(2, 2), WithBacking(gtd)) + buf, err := T.PBEncode() + if err != nil { + t.Errorf("UNPOSSIBLE!: %v", err) + continue + } + + T2 := new(Dense) + if err = T2.PBDecode(buf); err != nil { + t.Errorf("Error while decoding %v: %v", gtd, err) + continue + } + + assert.Equal(T.Shape(), T2.Shape()) + assert.Equal(T.Strides(), T2.Strides()) + assert.Equal(T.Data(), T2.Data()) + + // TODO: MASKED ARRAY + } } diff --git a/dense_linalg.go b/dense_linalg.go index ca07663..10eb936 100644 --- a/dense_linalg.go +++ b/dense_linalg.go @@ -1,13 +1,16 @@ package tensor -import "github.com/pkg/errors" +import ( + "github.com/pkg/errors" + "gorgonia.org/dtype" +) // Trace returns the trace of the matrix (i.e. the sum of the diagonal elements). It only works for matrices func (t *Dense) Trace() (retVal interface{}, err error) { e := t.e - + ctx := ctxFromEngine(e) if tracer, ok := e.(Tracer); ok { - return tracer.Trace(t) + return tracer.Trace(ctx, t) } return nil, errors.Errorf("Engine %T does not support Trace", e) } @@ -15,8 +18,8 @@ func (t *Dense) Trace() (retVal interface{}, err error) { // Inner performs a dot product on two vectors. If t or other are not vectors, it will return an error. func (t *Dense) Inner(other Tensor) (retVal interface{}, err error) { // check that the data is a float - if !isFloat(t.t) { - return nil, errors.Errorf(unsupportedDtype, t.t, "Inner") + if err = dtype.TypeClassCheck(t.t, dtype.FloatComplex); err != nil { + return nil, errors.Wrapf(err, unsupportedDtype, t.t, "Inner") } // check both are vectors @@ -31,13 +34,14 @@ func (t *Dense) Inner(other Tensor) (retVal interface{}, err error) { } e := t.e + ctx := ctxFromEngine(e) switch ip := e.(type) { case InnerProderF32: - return ip.Inner(t, other) + return ip.Inner(ctx, t, other) case InnerProderF64: - return ip.Inner(t, other) + return ip.Inner(ctx, t, other) case InnerProder: - return ip.Inner(t, other) + return ip.Inner(ctx, t, other) } return nil, errors.Errorf("Engine does not support Inner()") @@ -80,19 +84,22 @@ func (t *Dense) MatVecMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err err // check whether retVal has the same size as the resulting matrix would be: mx1 fo := ParseFuncOpts(opts...) 
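Trace and Inner now pull a context from the engine and dispatch through the context-aware Tracer/InnerProder interfaces, and Inner's old isFloat check becomes dtype.TypeClassCheck(t.t, dtype.FloatComplex). The caller-facing API is unchanged; a small sketch:

```go
package main

import (
	"fmt"
	"log"

	"gorgonia.org/tensor"
)

func main() {
	a := tensor.New(tensor.WithShape(3), tensor.WithBacking([]float64{1, 2, 3}))
	b := tensor.New(tensor.WithShape(3), tensor.WithBacking([]float64{4, 5, 6}))

	dot, err := a.Inner(b) // 1*4 + 2*5 + 3*6
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(dot) // 32

	m := tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4}))
	tr, err := m.Trace() // sum of the diagonal
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(tr) // 5
}
```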
defer returnOpOpt(fo) - if retVal, err = handleReuse(fo.Reuse(), expectedShape); err != nil { + if retVal, err = handleReuse(fo.Reuse(), expectedShape, fo.Safe()); err != nil { err = errors.Wrapf(err, opFail, "MatVecMul") return } if retVal == nil { - retVal = recycledDense(t.t, expectedShape) + retVal = recycledDense(t.t, expectedShape, WithEngine(t.e)) + if t.o.IsColMajor() { + AsFortran(nil)(retVal) + } } + ctx := fo.Context() e := t.e - if mvm, ok := e.(MatVecMuler); ok { - if err = mvm.MatVecMul(t, other, retVal); err != nil { + if err = mvm.MatVecMul(ctx, t, other, retVal); err != nil { return nil, errors.Wrapf(err, opFail, "MatVecMul") } return handleIncr(retVal, fo.Reuse(), fo.Incr(), expectedShape) @@ -126,19 +133,22 @@ func (t *Dense) MatMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) fo := ParseFuncOpts(opts...) defer returnOpOpt(fo) - if retVal, err = handleReuse(fo.Reuse(), expectedShape); err != nil { + if retVal, err = handleReuse(fo.Reuse(), expectedShape, fo.Safe()); err != nil { err = errors.Wrapf(err, opFail, "MatMul") return } if retVal == nil { - retVal = recycledDense(t.t, expectedShape) + retVal = recycledDense(t.t, expectedShape, WithEngine(t.e)) + if t.o.IsColMajor() { + AsFortran(nil)(retVal) + } } + ctx := fo.Context() e := t.e - if mm, ok := e.(MatMuler); ok { - if err = mm.MatMul(t, other, retVal); err != nil { + if err = mm.MatMul(ctx, t, other, retVal); err != nil { return } return handleIncr(retVal, fo.Reuse(), fo.Incr(), expectedShape) @@ -149,12 +159,6 @@ func (t *Dense) MatMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) // Outer finds the outer product of two vectors func (t *Dense) Outer(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) { - // check both are vectors - if !t.Shape().IsVector() || !other.Shape().IsVector() { - err = errors.Errorf("Outer only works when there are two vectors. t's shape: %v. other's shape: %v", t.Shape(), other.Shape()) - return - } - m := t.Size() n := other.Size() @@ -163,21 +167,25 @@ func (t *Dense) Outer(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) fo := ParseFuncOpts(opts...) defer returnOpOpt(fo) - if retVal, err = handleReuse(fo.Reuse(), expectedShape); err != nil { + if retVal, err = handleReuse(fo.Reuse(), expectedShape, fo.Safe()); err != nil { err = errors.Wrapf(err, opFail, "Outer") return } if retVal == nil { - retVal = recycledDense(t.t, expectedShape) + retVal = recycledDense(t.t, expectedShape, WithEngine(t.e)) + if t.o.IsColMajor() { + AsFortran(nil)(retVal) + } } + ctx := fo.Context() e := t.e // DGER does not have any beta. 
So the values have to be zeroed first if the tensor is to be reused retVal.Zero() if op, ok := e.(OuterProder); ok { - if err = op.Outer(t, other, retVal); err != nil { + if err = op.Outer(ctx, t, other, retVal); err != nil { return nil, errors.Wrapf(err, opFail, "engine.uter") } return handleIncr(retVal, fo.Reuse(), fo.Incr(), expectedShape) @@ -310,7 +318,6 @@ func (t *Dense) TensorMul(other Tensor, axesA, axesB []int) (retVal *Dense, err return } doOther.Transpose() - if err = doOther.Reshape(newShapeO...); err != nil { return } @@ -346,10 +353,10 @@ func (t *Dense) TensorMul(other Tensor, axesA, axesB []int) (retVal *Dense, err // In the future, when gonum/lapack fully supports float32, we'll look into rewriting this func (t *Dense) SVD(uv, full bool) (s, u, v *Dense, err error) { e := t.Engine() - + ctx := ctxFromEngine(e) if svder, ok := e.(SVDer); ok { var sT, uT, vT Tensor - if sT, uT, vT, err = svder.SVD(t, uv, full); err != nil { + if sT, uT, vT, err = svder.SVD(ctx, t, uv, full); err != nil { return nil, nil, nil, errors.Wrap(err, "Error while performing *Dense.SVD") } if s, err = assertDense(sT); err != nil { @@ -371,20 +378,22 @@ func (t *Dense) SVD(uv, full bool) (s, u, v *Dense, err error) { /* UTILITY FUNCTIONS */ // handleReuse extracts a *Dense from Tensor, and checks the shape of the reuse Tensor -func handleReuse(reuse Tensor, expectedShape Shape) (retVal *Dense, err error) { +func handleReuse(reuse Tensor, expectedShape Shape, safe bool) (retVal *Dense, err error) { if reuse != nil { if retVal, err = assertDense(reuse); err != nil { err = errors.Wrapf(err, opFail, "handling reuse") return } - - if err = reuseCheckShape(retVal, expectedShape); err != nil { + if !safe { + return + } + if err = checkFixShape(retVal, expectedShape); err != nil { err = errors.Wrapf(err, "Unable to process reuse *Dense Tensor. Shape error.") return } return } - return + return nil, nil } // handleIncr is the cleanup step for when there is an Tensor to increment. If the result tensor is the same as the reuse Tensor, the result tensor gets returned to the pool @@ -402,7 +411,7 @@ func handleIncr(res *Dense, reuse, incr Tensor, expectedShape Shape) (retVal *De return } - if err = typeclassCheck(incrD.t, numberTypes); err != nil { + if err = dtype.TypeClassCheck(incrD.t, dtype.Number); err != nil { err = errors.Wrapf(err, "handleIncr only handles Number types. 
Got %v instead", incrD.t) return } diff --git a/dense_linalg_test.go b/dense_linalg_test.go index bfd316c..17b6fcd 100644 --- a/dense_linalg_test.go +++ b/dense_linalg_test.go @@ -10,6 +10,7 @@ import ( type linalgTest struct { a, b interface{} shapeA, shapeB Shape + transA, transB bool reuse, incr interface{} shapeR, shapeI Shape @@ -118,89 +119,94 @@ func TestDense_Inner(t *testing.T) { var matVecMulTests = []linalgTest{ // Float64s - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, false}, - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3, 1}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3, 1}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, false}, - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{1, 3}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{1, 3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, false}, + // float64s with transposed matrix + {Range(Float64, 0, 6), Range(Float64, 0, 2), Shape{2, 3}, Shape{2}, true, false, + Range(Float64, 52, 55), Range(Float64, 100, 103), Shape{3}, Shape{3}, + []float64{3, 4, 5}, []float64{103, 105, 107}, []float64{106, 109, 112}, Shape{3}, false, false, false}, + // Float32s - {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2}, []float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false}, - {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3, 1}, + {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3, 1}, false, false, Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2}, []float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false}, - {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{1, 3}, + {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{1, 3}, false, false, Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2}, []float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false}, // stupids : unpossible shapes (wrong A) - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{6}, Shape{3}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{6}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, //stupids: bad A shape - {Range(Float64, 0, 8), Range(Float64, 0, 3), Shape{4, 2}, Shape{3}, + {Range(Float64, 0, 8), Range(Float64, 0, 3), Shape{4, 2}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, //stupids: bad B shape - {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, 
Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, //stupids: bad reuse - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 55), Range(Float64, 100, 102), Shape{3}, Shape{2}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, true}, //stupids: bad incr shape - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 105), Shape{2}, Shape{5}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, true, false}, // stupids: type mismatch A and B - {Range(Float64, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, // stupids: type mismatch A and B - {Range(Float32, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float32, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, // stupids: type mismatch A and B - {Range(Float64, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, // stupids: type mismatch A and B - {Range(Float32, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float32, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, // stupids: type mismatch A and B (non-Float) - {Range(Float64, 0, 6), Range(Int, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), Range(Int, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, // stupids: type mismatch, reuse - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float32, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, true}, // stupids: type mismatch, incr - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float32, 100, 103), Shape{2}, Shape{3}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, true, false}, // stupids: type mismatch, incr not a Number - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), []bool{true, true, true}, Shape{2}, Shape{3}, 
[]float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, true, false}, } @@ -211,12 +217,19 @@ func TestDense_MatVecMul(t *testing.T) { a := New(WithBacking(mvmt.a), WithShape(mvmt.shapeA...)) b := New(WithBacking(mvmt.b), WithShape(mvmt.shapeB...)) + if mvmt.transA { + if err := a.T(); err != nil { + t.Error(err) + continue + } + } T, err := a.MatVecMul(b) if checkErr(t, mvmt.err, err, "Safe", i) { continue } assert.True(mvmt.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsRowMajor()) assert.Equal(mvmt.correct, T.Data()) // incr @@ -227,6 +240,7 @@ func TestDense_MatVecMul(t *testing.T) { } assert.True(mvmt.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsRowMajor()) assert.Equal(mvmt.correctIncr, T.Data()) // reuse @@ -237,6 +251,7 @@ func TestDense_MatVecMul(t *testing.T) { } assert.True(mvmt.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsRowMajor()) assert.Equal(mvmt.correct, T.Data()) // reuse AND incr @@ -251,89 +266,89 @@ func TestDense_MatVecMul(t *testing.T) { var matMulTests = []linalgTest{ // Float64s - {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, false, false, false}, // Float32s - {Range(Float32, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float32, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float32, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2}, []float32{10, 13, 28, 40}, []float32{110, 114, 130, 143}, []float32{120, 127, 158, 183}, Shape{2, 2}, false, false, false}, // Edge cases - Row Vecs (Float64) - {Range(Float64, 0, 2), Range(Float64, 0, 3), Shape{2, 1}, Shape{1, 3}, + {Range(Float64, 0, 2), Range(Float64, 0, 3), Shape{2, 1}, Shape{1, 3}, false, false, Range(Float64, 10, 16), Range(Float64, 100, 106), Shape{2, 3}, Shape{2, 3}, []float64{0, 0, 0, 0, 1, 2}, []float64{100, 101, 102, 103, 105, 107}, []float64{100, 101, 102, 103, 106, 109}, Shape{2, 3}, false, false, false}, - {Range(Float64, 0, 2), Range(Float64, 0, 6), Shape{1, 2}, Shape{2, 3}, + {Range(Float64, 0, 2), Range(Float64, 0, 6), Shape{1, 2}, Shape{2, 3}, false, false, Range(Float64, 10, 13), Range(Float64, 100, 103), Shape{1, 3}, Shape{1, 3}, []float64{3, 4, 5}, []float64{103, 105, 107}, []float64{106, 109, 112}, Shape{1, 3}, false, false, false}, - {Range(Float64, 0, 2), Range(Float64, 0, 2), Shape{1, 2}, Shape{2, 1}, + {Range(Float64, 0, 2), Range(Float64, 0, 2), Shape{1, 2}, Shape{2, 1}, false, false, Range(Float64, 0, 1), Range(Float64, 100, 101), Shape{1, 1}, Shape{1, 1}, []float64{1}, []float64{101}, []float64{102}, Shape{1, 1}, false, false, false}, // Edge cases - Row Vecs (Float32) - {Range(Float32, 0, 2), Range(Float32, 0, 3), Shape{2, 1}, Shape{1, 3}, + {Range(Float32, 0, 2), Range(Float32, 0, 3), Shape{2, 1}, Shape{1, 3}, false, false, Range(Float32, 10, 16), Range(Float32, 100, 106), Shape{2, 3}, Shape{2, 3}, []float32{0, 0, 0, 0, 1, 2}, []float32{100, 101, 102, 103, 105, 107}, []float32{100, 101, 102, 103, 106, 109}, Shape{2, 3}, false, false, false}, - {Range(Float32, 0, 2), Range(Float32, 0, 6), Shape{1, 2}, Shape{2, 3}, + {Range(Float32, 0, 2), Range(Float32, 0, 6), Shape{1, 2}, Shape{2, 3}, false, false, Range(Float32, 10, 13), Range(Float32, 100, 103), Shape{1, 3}, Shape{1, 3}, []float32{3, 4, 
5}, []float32{103, 105, 107}, []float32{106, 109, 112}, Shape{1, 3}, false, false, false}, - {Range(Float32, 0, 2), Range(Float32, 0, 2), Shape{1, 2}, Shape{2, 1}, + {Range(Float32, 0, 2), Range(Float32, 0, 2), Shape{1, 2}, Shape{2, 1}, false, false, Range(Float32, 0, 1), Range(Float32, 100, 101), Shape{1, 1}, Shape{1, 1}, []float32{1}, []float32{101}, []float32{102}, Shape{1, 1}, false, false, false}, // stupids - bad shape (not matrices): - {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{6}, + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{6}, false, false, Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, true, false, false}, // stupids - bad shape (incompatible shapes): - {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{6, 1}, + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{6, 1}, false, false, Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, true, false, false}, // stupids - bad shape (bad reuse shape): - {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 57), Range(Float64, 100, 104), Shape{5}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, false, false, true}, // stupids - bad shape (bad incr shape): - {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{4}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, false, true, false}, // stupids - type mismatch (a,b) - {Range(Float64, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float64, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, true, false, false}, // stupids - type mismatch (a,b) - {Range(Float32, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float32, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, true, false, false}, // stupids type mismatch (b not float) - {Range(Float64, 0, 6), Range(Int, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float64, 0, 6), Range(Int, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, true, false, false}, // stupids type mismatch (a not float) - {Range(Int, 0, 6), Range(Int, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Int, 0, 6), Range(Int, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, true, false, false}, 
// stupids: type mismatch (incr) - {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, false, true, false}, // stupids: type mismatch (reuse) - {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float32, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, false, false, true}, // stupids: type mismatch (reuse) - {Range(Float32, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float32, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2}, []float32{10, 13, 28, 40}, []float32{110, 114, 130, 143}, []float32{120, 127, 158, 183}, Shape{2, 2}, false, false, true}, } @@ -382,55 +397,57 @@ func TestDense_MatMul(t *testing.T) { var outerTests = []linalgTest{ // Float64s - {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, false, false, false}, // Float32s - {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, false, false, Range(Float32, 52, 61), Range(Float32, 100, 109), Shape{3, 3}, Shape{3, 3}, []float32{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float32{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float32{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, false, false, false}, + /* TODO: this test is no longer valid with the new impl of outer // stupids - a or b not vector - {Range(Float64, 0, 3), Range(Float64, 0, 6), Shape{3}, Shape{3, 2}, + {Range(Float64, 0, 3), Range(Float64, 0, 6), Shape{3}, Shape{3, 2}, false, false, Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, true, false, false}, + */ // stupids - bad incr shape - {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, Range(Float64, 52, 61), Range(Float64, 100, 106), Shape{3, 3}, Shape{3, 2}, []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, false, true, false}, // stupids - bad reuse shape - {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, Range(Float64, 52, 58), Range(Float64, 100, 109), Shape{3, 2}, Shape{3, 3}, []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, false, false, true}, // stupids - b not Float - {Range(Float64, 0, 3), 
Range(Int, 0, 3), Shape{3}, Shape{3}, + {Range(Float64, 0, 3), Range(Int, 0, 3), Shape{3}, Shape{3}, false, false, Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, true, false, false}, // stupids - a not Float - {Range(Int, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, + {Range(Int, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, true, false, false}, // stupids - a-b type mismatch - {Range(Float64, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, + {Range(Float64, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, false, false, Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, true, false, false}, // stupids a-b type mismatch - {Range(Float32, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, + {Range(Float32, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, true, false, false}, diff --git a/dense_mapreduce.go b/dense_mapreduce.go index 6072614..2677fe7 100644 --- a/dense_mapreduce.go +++ b/dense_mapreduce.go @@ -11,7 +11,7 @@ func (t *Dense) Apply(fn interface{}, opts ...FuncOpt) (retVal Tensor, err error if m, ok := e.(Mapper); ok { return m.Map(fn, t, opts...) } - return nil, errors.Errorf("Execution engine for %v not a mapper", t) + return nil, errors.Errorf("Execution engine %T for %v not a mapper", e, t) } // Reduce applies a reduction function and reduces the values along the given axis. 
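Apply's error message now names the engine type as well as the tensor, which makes it obvious which engine refused to act as a Mapper. For reference, a minimal Apply call on the default engine:

```go
package main

import (
	"fmt"
	"log"

	"gorgonia.org/tensor"
)

func main() {
	T := tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4}))

	// Apply hands the function to the engine's Map implementation.
	squared, err := T.Apply(func(x float64) float64 { return x * x })
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(squared.Data()) // [1 4 9 16]
}
```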
diff --git a/dense_mask_filling.go b/dense_mask_filling.go index f5d45c7..a31b5aa 100644 --- a/dense_mask_filling.go +++ b/dense_mask_filling.go @@ -72,7 +72,7 @@ func (t *Dense) Filled(val ...interface{}) (interface{}, error) { for i := range sliceList { tt, err := tc.Slice(nil, sliceList[i]) if err != nil { - ts := tt.(*Dense) + ts := tt.(DenseView) ts.Memset(fillval) } } @@ -107,7 +107,7 @@ func (t *Dense) FilledInplace(val ...interface{}) (interface{}, error) { for i := range sliceList { tt, err := t.Slice(nil, sliceList[i]) if err != nil { - ts := tt.(*Dense) + ts := tt.(DenseView) ts.Memset(fillval) } } diff --git a/dense_mask_inspection.go b/dense_mask_inspection.go index d2e7843..7e1c30c 100644 --- a/dense_mask_inspection.go +++ b/dense_mask_inspection.go @@ -1,10 +1,12 @@ package tensor +import "gorgonia.org/dtype" + type maskedReduceFn func(Tensor) interface{} // MaskedReduce applies a reduction function of type maskedReduceFn to mask, and returns // either an int, or another array -func MaskedReduce(t *Dense, retType Dtype, fn maskedReduceFn, axis ...int) interface{} { +func MaskedReduce(t *Dense, retType dtype.Dtype, fn maskedReduceFn, axis ...int) interface{} { if len(axis) == 0 || t.IsVector() { return fn(t) } @@ -18,7 +20,7 @@ func MaskedReduce(t *Dense, retType Dtype, fn maskedReduceFn, axis ...int) inter // calculate shape of tensor to be returned slices[ax] = makeRS(0, 0) tt, _ := t.Slice(slices...) - ts := tt.(*Dense) + ts := MustGetDense(tt) retVal := NewDense(retType, ts.shape) //retVal is array to be returned it := NewIterator(retVal.Info()) @@ -37,7 +39,7 @@ func MaskedReduce(t *Dense, retType Dtype, fn maskedReduceFn, axis ...int) inter } } tt, _ = t.Slice(slices...) - ts = tt.(*Dense) + ts = MustGetDense(tt) retVal.SetAt(fn(ts), coord...) } diff --git a/dense_mask_inspection_test.go b/dense_mask_inspection_test.go index 7bd118f..ea3574f 100644 --- a/dense_mask_inspection_test.go +++ b/dense_mask_inspection_test.go @@ -124,7 +124,7 @@ func TestMaskedFindContiguous(t *testing.T) { T.ResetMask(true) for i := range sliceList { tt, _ := T.Slice(nil, sliceList[i]) - ts := tt.(*Dense) + ts := MustGetDense(tt) ts.ResetMask(false) } retSL = T.FlatNotMaskedContiguous() @@ -137,7 +137,7 @@ func TestMaskedFindContiguous(t *testing.T) { T.ResetMask(false) for i := range sliceList { tt, _ := T.Slice(nil, sliceList[i]) - ts := tt.(*Dense) + ts := MustGetDense(tt) ts.ResetMask(true) } retSL = T.FlatMaskedContiguous() @@ -158,7 +158,7 @@ func TestMaskedFindEdges(t *testing.T) { T.ResetMask(false) for i := range sliceList { tt, _ := T.Slice(nil, sliceList[i]) - ts := tt.(*Dense) + ts := MustGetDense(tt) ts.ResetMask(true) } start, end := T.FlatNotMaskedEdges() @@ -169,7 +169,7 @@ func TestMaskedFindEdges(t *testing.T) { T.ResetMask(true) for i := range sliceList { tt, _ := T.Slice(nil, sliceList[i]) - ts := tt.(*Dense) + ts := MustGetDense(tt) ts.ResetMask(false) } start, end = T.FlatMaskedEdges() diff --git a/dense_maskcmp_methods.go b/dense_maskcmp_methods.go index 4cc3d95..d4b415a 100644 --- a/dense_maskcmp_methods.go +++ b/dense_maskcmp_methods.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -9,6 +7,8 @@ import ( "github.com/pkg/errors" ) +// Code generated by genlib2. DO NOT EDIT. 
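The generated masked-comparison methods in dense_maskcmp_methods.go set (or extend) the mask wherever an element equals a value or falls inside/outside a pair of bounds. A short sketch of how they compose; error returns are elided for brevity:

```go
package main

import (
	"fmt"

	"gorgonia.org/tensor"
)

func main() {
	T := tensor.New(tensor.WithShape(5), tensor.WithBacking([]float64{0, 1, 2, 3, 4}))

	T.MaskedEqual(0.0)        // mask elements equal to 0
	T.MaskedInside(2.0, 3.0)  // additionally mask elements in [2, 3]
	fmt.Println(T.IsMasked()) // true

	T.ResetMask() // clear the mask again
}
```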
+ /* MaskedEqual */ // MaskedEqual sets the mask to true where the corresponding data is equal to val diff --git a/dense_maskcmp_methods_test.go b/dense_maskcmp_methods_test.go index 94e365c..e48e89c 100644 --- a/dense_maskcmp_methods_test.go +++ b/dense_maskcmp_methods_test.go @@ -1,13 +1,14 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( + "fmt" "testing" "github.com/stretchr/testify/assert" ) +// Code generated by genlib2. DO NOT EDIT. + /* MaskedEqual */ func TestDense_MaskedEqual_I(t *testing.T) { @@ -604,26 +605,26 @@ func TestDense_MaskedEqual_Str(t *testing.T) { assert.False(T.IsMasked()) data := T.Strings() for i := range data { - data[i] = string(i) + data[i] = fmt.Sprint(i) } - T.MaskedEqual(string(0)) + T.MaskedEqual(fmt.Sprint(0)) assert.True(T.IsMasked()) - T.MaskedEqual(string(1)) + T.MaskedEqual(fmt.Sprint(1)) assert.True(T.mask[0] && T.mask[1]) - T.MaskedNotEqual(string(2)) + T.MaskedNotEqual(fmt.Sprint(2)) assert.False(T.mask[2] && !(T.mask[0])) T.ResetMask() - T.MaskedInside(string(1), string(22)) + T.MaskedInside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(!T.mask[0] && !T.mask[23] && T.mask[1] && T.mask[22]) T.ResetMask() - T.MaskedOutside(string(1), string(22)) + T.MaskedOutside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(T.mask[0] && T.mask[23] && !T.mask[1] && !T.mask[22]) T.ResetMask() for i := 0; i < 5; i++ { - T.MaskedEqual(string(i * 10)) + T.MaskedEqual(fmt.Sprint(i * 10)) } it := IteratorFromDense(T) @@ -1244,26 +1245,26 @@ func TestDense_MaskedNotEqual_Str(t *testing.T) { assert.False(T.IsMasked()) data := T.Strings() for i := range data { - data[i] = string(i) + data[i] = fmt.Sprint(i) } - T.MaskedEqual(string(0)) + T.MaskedEqual(fmt.Sprint(0)) assert.True(T.IsMasked()) - T.MaskedEqual(string(1)) + T.MaskedEqual(fmt.Sprint(1)) assert.True(T.mask[0] && T.mask[1]) - T.MaskedNotEqual(string(2)) + T.MaskedNotEqual(fmt.Sprint(2)) assert.False(T.mask[2] && !(T.mask[0])) T.ResetMask() - T.MaskedInside(string(1), string(22)) + T.MaskedInside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(!T.mask[0] && !T.mask[23] && T.mask[1] && T.mask[22]) T.ResetMask() - T.MaskedOutside(string(1), string(22)) + T.MaskedOutside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(T.mask[0] && T.mask[23] && !T.mask[1] && !T.mask[22]) T.ResetMask() for i := 0; i < 5; i++ { - T.MaskedEqual(string(i * 10)) + T.MaskedEqual(fmt.Sprint(i * 10)) } it := IteratorFromDense(T) @@ -1985,26 +1986,26 @@ func TestDense_MaskedGreater_Str(t *testing.T) { assert.False(T.IsMasked()) data := T.Strings() for i := range data { - data[i] = string(i) + data[i] = fmt.Sprint(i) } - T.MaskedEqual(string(0)) + T.MaskedEqual(fmt.Sprint(0)) assert.True(T.IsMasked()) - T.MaskedEqual(string(1)) + T.MaskedEqual(fmt.Sprint(1)) assert.True(T.mask[0] && T.mask[1]) - T.MaskedNotEqual(string(2)) + T.MaskedNotEqual(fmt.Sprint(2)) assert.False(T.mask[2] && !(T.mask[0])) T.ResetMask() - T.MaskedInside(string(1), string(22)) + T.MaskedInside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(!T.mask[0] && !T.mask[23] && T.mask[1] && T.mask[22]) T.ResetMask() - T.MaskedOutside(string(1), string(22)) + T.MaskedOutside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(T.mask[0] && T.mask[23] && !T.mask[1] && !T.mask[22]) T.ResetMask() for i := 0; i < 5; i++ { - T.MaskedEqual(string(i * 10)) + T.MaskedEqual(fmt.Sprint(i * 10)) } it := IteratorFromDense(T) @@ -2625,26 +2626,26 @@ func TestDense_MaskedGreaterEqual_Str(t *testing.T) { assert.False(T.IsMasked()) data := T.Strings() for i := range data { - data[i] = string(i) + data[i] = 
fmt.Sprint(i) } - T.MaskedEqual(string(0)) + T.MaskedEqual(fmt.Sprint(0)) assert.True(T.IsMasked()) - T.MaskedEqual(string(1)) + T.MaskedEqual(fmt.Sprint(1)) assert.True(T.mask[0] && T.mask[1]) - T.MaskedNotEqual(string(2)) + T.MaskedNotEqual(fmt.Sprint(2)) assert.False(T.mask[2] && !(T.mask[0])) T.ResetMask() - T.MaskedInside(string(1), string(22)) + T.MaskedInside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(!T.mask[0] && !T.mask[23] && T.mask[1] && T.mask[22]) T.ResetMask() - T.MaskedOutside(string(1), string(22)) + T.MaskedOutside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(T.mask[0] && T.mask[23] && !T.mask[1] && !T.mask[22]) T.ResetMask() for i := 0; i < 5; i++ { - T.MaskedEqual(string(i * 10)) + T.MaskedEqual(fmt.Sprint(i * 10)) } it := IteratorFromDense(T) @@ -3265,26 +3266,26 @@ func TestDense_MaskedLess_Str(t *testing.T) { assert.False(T.IsMasked()) data := T.Strings() for i := range data { - data[i] = string(i) + data[i] = fmt.Sprint(i) } - T.MaskedEqual(string(0)) + T.MaskedEqual(fmt.Sprint(0)) assert.True(T.IsMasked()) - T.MaskedEqual(string(1)) + T.MaskedEqual(fmt.Sprint(1)) assert.True(T.mask[0] && T.mask[1]) - T.MaskedNotEqual(string(2)) + T.MaskedNotEqual(fmt.Sprint(2)) assert.False(T.mask[2] && !(T.mask[0])) T.ResetMask() - T.MaskedInside(string(1), string(22)) + T.MaskedInside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(!T.mask[0] && !T.mask[23] && T.mask[1] && T.mask[22]) T.ResetMask() - T.MaskedOutside(string(1), string(22)) + T.MaskedOutside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(T.mask[0] && T.mask[23] && !T.mask[1] && !T.mask[22]) T.ResetMask() for i := 0; i < 5; i++ { - T.MaskedEqual(string(i * 10)) + T.MaskedEqual(fmt.Sprint(i * 10)) } it := IteratorFromDense(T) @@ -3905,26 +3906,26 @@ func TestDense_MaskedLessEqual_Str(t *testing.T) { assert.False(T.IsMasked()) data := T.Strings() for i := range data { - data[i] = string(i) + data[i] = fmt.Sprint(i) } - T.MaskedEqual(string(0)) + T.MaskedEqual(fmt.Sprint(0)) assert.True(T.IsMasked()) - T.MaskedEqual(string(1)) + T.MaskedEqual(fmt.Sprint(1)) assert.True(T.mask[0] && T.mask[1]) - T.MaskedNotEqual(string(2)) + T.MaskedNotEqual(fmt.Sprint(2)) assert.False(T.mask[2] && !(T.mask[0])) T.ResetMask() - T.MaskedInside(string(1), string(22)) + T.MaskedInside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(!T.mask[0] && !T.mask[23] && T.mask[1] && T.mask[22]) T.ResetMask() - T.MaskedOutside(string(1), string(22)) + T.MaskedOutside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(T.mask[0] && T.mask[23] && !T.mask[1] && !T.mask[22]) T.ResetMask() for i := 0; i < 5; i++ { - T.MaskedEqual(string(i * 10)) + T.MaskedEqual(fmt.Sprint(i * 10)) } it := IteratorFromDense(T) @@ -4545,26 +4546,26 @@ func TestDense_MaskedInside_Str(t *testing.T) { assert.False(T.IsMasked()) data := T.Strings() for i := range data { - data[i] = string(i) + data[i] = fmt.Sprint(i) } - T.MaskedEqual(string(0)) + T.MaskedEqual(fmt.Sprint(0)) assert.True(T.IsMasked()) - T.MaskedEqual(string(1)) + T.MaskedEqual(fmt.Sprint(1)) assert.True(T.mask[0] && T.mask[1]) - T.MaskedNotEqual(string(2)) + T.MaskedNotEqual(fmt.Sprint(2)) assert.False(T.mask[2] && !(T.mask[0])) T.ResetMask() - T.MaskedInside(string(1), string(22)) + T.MaskedInside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(!T.mask[0] && !T.mask[23] && T.mask[1] && T.mask[22]) T.ResetMask() - T.MaskedOutside(string(1), string(22)) + T.MaskedOutside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(T.mask[0] && T.mask[23] && !T.mask[1] && !T.mask[22]) T.ResetMask() for i := 0; i < 5; i++ { - T.MaskedEqual(string(i * 10)) + 
T.MaskedEqual(fmt.Sprint(i * 10)) } it := IteratorFromDense(T) @@ -5185,26 +5186,26 @@ func TestDense_MaskedOutside_Str(t *testing.T) { assert.False(T.IsMasked()) data := T.Strings() for i := range data { - data[i] = string(i) + data[i] = fmt.Sprint(i) } - T.MaskedEqual(string(0)) + T.MaskedEqual(fmt.Sprint(0)) assert.True(T.IsMasked()) - T.MaskedEqual(string(1)) + T.MaskedEqual(fmt.Sprint(1)) assert.True(T.mask[0] && T.mask[1]) - T.MaskedNotEqual(string(2)) + T.MaskedNotEqual(fmt.Sprint(2)) assert.False(T.mask[2] && !(T.mask[0])) T.ResetMask() - T.MaskedInside(string(1), string(22)) + T.MaskedInside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(!T.mask[0] && !T.mask[23] && T.mask[1] && T.mask[22]) T.ResetMask() - T.MaskedOutside(string(1), string(22)) + T.MaskedOutside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(T.mask[0] && T.mask[23] && !T.mask[1] && !T.mask[22]) T.ResetMask() for i := 0; i < 5; i++ { - T.MaskedEqual(string(i * 10)) + T.MaskedEqual(fmt.Sprint(i * 10)) } it := IteratorFromDense(T) diff --git a/dense_matop.go b/dense_matop.go index 46e8a55..3a9f005 100644 --- a/dense_matop.go +++ b/dense_matop.go @@ -1,21 +1,22 @@ package tensor -import "github.com/pkg/errors" +import ( + "github.com/pkg/errors" +) // T performs a thunked transpose. It doesn't actually do anything, except store extra information about the post-transposed shapes and strides // Usually this is more than enough, as BLAS will handle the rest of the transpose func (t *Dense) T(axes ...int) (err error) { - var transform *AP + var transform AP if transform, axes, err = t.AP.T(axes...); err != nil { return handleNoOp(err) } // is there any old transposes that need to be done first? // this is important, because any old transposes for dim >=3 are merely permutations of the strides - if t.old != nil { + if !t.old.IsZero() { if t.IsVector() { // the transform that was calculated was a waste of time - return it to the pool then untranspose - ReturnAP(transform) t.UT() return } @@ -31,7 +32,6 @@ func (t *Dense) T(axes ...int) (err error) { // if it is reversed, well, we just restore the backed up one if isReversed { - ReturnAP(transform) t.UT() return } @@ -58,31 +58,30 @@ func (t *Dense) T(axes ...int) (err error) { // // Nothing will happen if there was no previous transpose func (t *Dense) UT() { - if t.old != nil { - ReturnAP(t.AP) + if !t.old.IsZero() { ReturnInts(t.transposeWith) t.AP = t.old - t.old = nil + t.old.zeroOnly() t.transposeWith = nil } } // SafeT is exactly like T(), except it returns a new *Dense. The data is also copied over, unmoved. 
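T() is a thunked transpose: it only swaps the shape and strides in the AP and stashes the previous AP in t.old (now a value, checked with IsZero rather than compared to nil). UT() puts the old AP back, while Transpose() in dense_matop_memmove.go physically rearranges the data. A quick sketch of the three:

```go
package main

import (
	"fmt"
	"log"

	"gorgonia.org/tensor"
)

func main() {
	A := tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{0, 1, 2, 3, 4, 5}))

	if err := A.T(); err != nil { // thunked: metadata only
		log.Fatal(err)
	}
	fmt.Println(A.Shape()) // (3, 2)

	A.UT() // undo the thunked transpose
	fmt.Println(A.Shape()) // (2, 3)

	if err := A.T(); err != nil {
		log.Fatal(err)
	}
	if err := A.Transpose(); err != nil { // materialize: move the data
		log.Fatal(err)
	}
	fmt.Println(A.Data()) // [0 3 1 4 2 5]
}
```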
func (t *Dense) SafeT(axes ...int) (retVal *Dense, err error) { - var transform *AP + var transform AP if transform, axes, err = t.AP.T(axes...); err != nil { if err = handleNoOp(err); err != nil { return } } - retVal = recycledDense(t.t, Shape{t.len()}) + retVal = recycledDense(t.t, Shape{t.len()}, WithEngine(t.e)) copyDense(retVal, t) retVal.e = t.e retVal.oe = t.oe retVal.AP = transform - retVal.old = t.AP.Clone() + t.AP.CloneTo(&retVal.old) retVal.transposeWith = axes return @@ -132,7 +131,9 @@ func (t *Dense) SetAt(v interface{}, coords ...int) error { return errors.Errorf(inaccessibleData, t) } - if len(coords) != t.Dims() { + switch { + case t.IsScalar() && len(coords) == 1: + case len(coords) != t.Dims(): return errors.Errorf(dimMismatch, t.Dims(), len(coords)) } @@ -196,7 +197,17 @@ func (t *Dense) CopyTo(other *Dense) error { } // TODO: use copyDenseIter - return errors.Errorf(methodNYI, "CopyTo", "views") + return nyierr(methodNYI, "views") +} + +// Narrow narrows the tensor. +func (t *Dense) Narrow(dim, start, length int) (View, error) { + dim = resolveAxis(dim, t.Dims()) + + slices := make([]Slice, MinInt(dim+1, t.Dims())) + slices[dim] = S(start, start+length, 1) + + return t.Slice(slices...) } // Slice performs slicing on the *Dense Tensor. It returns a view which shares the same underlying memory as the original *Dense. @@ -209,7 +220,7 @@ func (t *Dense) CopyTo(other *Dense) error { // // The method treats as equivalent to a colon slice. T.Slice(nil) is equivalent to T[:] in Numpy syntax func (t *Dense) Slice(slices ...Slice) (retVal View, err error) { - var newAP *AP + var newAP AP var ndStart, ndEnd int if newAP, ndStart, ndEnd, err = t.AP.S(t.len(), slices...); err != nil { @@ -229,23 +240,42 @@ func (t *Dense) Slice(slices ...Slice) (retVal View, err error) { view.mask = t.mask[ndStart:ndEnd] } - return view, err + return DenseView{view}, err } // SliceInto is a convenience method. It does NOT copy the values - it simply updates the AP of the view. // The underlying data is the same. // This method will override ALL the metadata in view. -func (t *Dense) SliceInto(view *Dense, slices ...Slice) (retVal View, err error) { - var newAP *AP +func (t *Dense) SliceInto(view Tensor, slices ...Slice) (retVal Tensor, err error) { + switch view := view.(type) { + case nil: + return t.Slice(slices...) + case DenseView: + v := view.Dense + if v, err = t.sliceIntoDense(v, slices...); err != nil { + return nil, err + } + return DenseView{v}, nil + + case *Dense: + if view, err = t.sliceIntoDense(view, slices...); err != nil { + return nil, err + } + return DenseView{view}, nil + default: + return nil, nyierr(typeNYI, view) + } +} + +func (t *Dense) sliceIntoDense(view *Dense, slices ...Slice) (retVal *Dense, err error) { + var newAP AP var ndStart, ndEnd int if newAP, ndStart, ndEnd, err = t.AP.S(t.len(), slices...); err != nil { - return + return nil, err } - ReturnAP(view.AP) - view.AP = nil - view.array.v = nil // reset + view.AP.zero() view.t = t.t view.e = t.e @@ -258,9 +288,7 @@ func (t *Dense) SliceInto(view *Dense, slices ...Slice) (retVal View, err error) if t.IsMasked() { view.mask = t.mask[ndStart:ndEnd] } - - return view, err - + return view, nil } // RollAxis rolls the axis backwards until it lies in the given position. 
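Narrow is a thin convenience over Slice: it restricts dimension dim to [start, start+length) and returns a view over the same memory (the TestDense_Narrow table below walks through the cases). A small sketch:

```go
package main

import (
	"fmt"
	"log"

	"gorgonia.org/tensor"
)

func main() {
	X := tensor.New(tensor.WithShape(3, 3), tensor.WithBacking([]int{1, 2, 3, 4, 5, 6, 7, 8, 9}))

	// Keep columns 1 and 2 of every row.
	Y, err := X.Narrow(1, 1, 2)
	if err != nil {
		log.Fatal(err)
	}

	// Y is a view over X; Materialize copies it into a compact tensor.
	fmt.Println(Y.Materialize().Data()) // [2 3 5 6 8 9]
}
```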
@@ -314,6 +342,7 @@ func (t *Dense) RollAxis(axis, start int, safe bool) (retVal *Dense, err error) func (t *Dense) transposeIndex(i int, transposePat, strides []int) int { oldCoord, err := Itol(i, t.oshape(), t.ostrides()) if err != nil { + err = errors.Wrapf(err, "transposeIndex ItoL failure. i %d original shape %v. original strides %v", i, t.oshape(), t.ostrides()) panic(err) } diff --git a/dense_matop_memmove.go b/dense_matop_memmove.go index 05033ef..9d63082 100644 --- a/dense_matop_memmove.go +++ b/dense_matop_memmove.go @@ -9,7 +9,7 @@ import "github.com/pkg/errors" // https://en.wikipedia.org/wiki/In-place_matrix_transposition func (t *Dense) Transpose() error { // if there is no oldinfo, that means the current info is the latest, and not the transpose - if t.old == nil { + if t.old.IsZero() { return nil } @@ -18,8 +18,7 @@ func (t *Dense) Transpose() error { } defer func() { - ReturnAP(t.old) - t.old = nil + t.old.zero() t.transposeWith = nil }() @@ -27,10 +26,10 @@ func (t *Dense) Transpose() error { // important! because the strides would have changed once the underlying data changed var expStrides []int - if t.AP.o.isColMajor() { - expStrides = expShape.calcStridesColMajor() + if t.AP.o.IsColMajor() { + expStrides = CalcStridesColMajor(expShape) } else { - expStrides = expShape.calcStrides() + expStrides = CalcStrides(expShape) } defer ReturnInts(expStrides) defer func() { @@ -44,13 +43,14 @@ func (t *Dense) Transpose() error { } // actually move data - var e Engine = t.e + e := t.Engine() + ctx := ctxFromEngine(e) transposer, ok := e.(Transposer) if !ok { return errors.Errorf("Engine does not support Transpose()") } - return transposer.Transpose(t, expStrides) + return transposer.Transpose(ctx, t, expStrides) } // Repeat is like Numpy's repeat. It repeats the elements of an array. @@ -58,9 +58,10 @@ func (t *Dense) Transpose() error { // Just like NumPy, the repeats param is broadcasted to fit the size of the given axis. func (t *Dense) Repeat(axis int, repeats ...int) (retVal Tensor, err error) { e := t.Engine() + ctx := ctxFromEngine(e) if rp, ok := e.(Repeater); ok { - return rp.Repeat(t, axis, repeats...) + return rp.Repeat(ctx, t, axis, repeats...) } return nil, errors.New("Engine does not support Repeat") } @@ -68,11 +69,12 @@ func (t *Dense) Repeat(axis int, repeats ...int) (retVal Tensor, err error) { // Concat concatenates the other tensors along the given axis. It is like Numpy's concatenate() function. func (t *Dense) Concat(axis int, Ts ...*Dense) (retVal *Dense, err error) { e := t.Engine() + ctx := ctxFromEngine(e) if c, ok := e.(Concater); ok { var ret Tensor others := densesToTensors(Ts) - if ret, err = c.Concat(t, axis, others...); err != nil { + if ret, err = c.Concat(ctx, t, axis, others...); err != nil { return nil, errors.Wrapf(err, opFail, "Concat") } return ret.(*Dense), nil @@ -128,8 +130,10 @@ func (t *Dense) Stack(axis int, others ...*Dense) (retVal *Dense, err error) { } func (t *Dense) stackDense(axis int, others ...DenseTensor) (retVal DenseTensor, err error) { + e := t.Engine() + ctx := ctxFromEngine(e) if ds, ok := t.Engine().(DenseStacker); ok { - return ds.StackDense(t, axis, others...) + return ds.StackDense(ctx, t, axis, others...) 
} return nil, errors.Errorf("Engine does not support DenseStacker") } diff --git a/dense_matop_test.go b/dense_matop_test.go index 51ee94a..2e3c9bb 100644 --- a/dense_matop_test.go +++ b/dense_matop_test.go @@ -1,9 +1,11 @@ package tensor import ( + "fmt" "testing" "github.com/stretchr/testify/assert" + "gorgonia.org/dtype" "gorgonia.org/vecf64" ) @@ -41,7 +43,7 @@ func cloneArray(a interface{}) interface{} { return nil } -func castToDt(val float64, dt Dtype) interface{} { +func castToDt(val float64, dt dtype.Dtype) interface{} { switch dt { case Bool: return false @@ -135,10 +137,10 @@ var transposeTests = []struct { correctData interface{} }{ {"c.T()", Shape{4, 1}, nil, []float64{0, 1, 2, 3}, - Shape{1, 4}, []int{1}, []int{1}, []float64{0, 1, 2, 3}}, + Shape{1, 4}, []int{1, 1}, []int{4, 1}, []float64{0, 1, 2, 3}}, {"r.T()", Shape{1, 4}, nil, []float32{0, 1, 2, 3}, - Shape{4, 1}, []int{1}, []int{1}, []float32{0, 1, 2, 3}}, + Shape{4, 1}, []int{1, 1}, []int{1, 1}, []float32{0, 1, 2, 3}}, {"v.T()", Shape{4}, nil, []int{0, 1, 2, 3}, Shape{4}, []int{1}, []int{1}, []int{0, 1, 2, 3}}, @@ -216,10 +218,10 @@ func TestDense_Transpose(t *testing.T) { } assert.True(tts.correctShape.Eq(T.Shape()), "Transpose %v Expected shape: %v. Got %v", tts.name, tts.correctShape, T.Shape()) - assert.Equal(tts.correctStrides, T.Strides()) + assert.Equal(tts.correctStrides, T.Strides(), "Transpose %v. Expected stride: %v. Got %v", tts.name, tts.correctStrides, T.Strides()) T.Transpose() assert.True(tts.correctShape.Eq(T.Shape()), "Transpose %v Expected shape: %v. Got %v", tts.name, tts.correctShape, T.Shape()) - assert.Equal(tts.correctStrides2, T.Strides(), "Transpose %v - Wrong strides", tts.name) + assert.Equal(tts.correctStrides2, T.Strides(), "Transpose2 %v - Expected stride %v. Got %v", tts.name, tts.correctStrides2, T.Strides()) assert.Equal(tts.correctData, T.Data(), "Transpose %v", tts.name) } @@ -236,7 +238,7 @@ func TestDense_Transpose(t *testing.T) { t.Errorf("Stacked .T() #1 for vector. Error: %v", err) goto matrev } - assert.Nil(T.old) + assert.True(T.old.IsZero()) assert.Nil(T.transposeWith) assert.True(T.IsColVec()) @@ -251,7 +253,7 @@ matrev: t.Errorf("Stacked .T() #2 for matrix reverse. 
Error: %v", err) goto matnorev } - assert.Nil(T.old) + assert.True(T.old.IsZero()) assert.Nil(T.transposeWith) assert.True(Shape{2, 3}.Eq(T.Shape())) @@ -278,16 +280,16 @@ func TestTUT(t *testing.T) { T = New(Of(Float64), WithShape(2, 3, 4)) T.T() T.UT() - assert.Nil(T.old) + assert.True(T.old.IsZero()) assert.Nil(T.transposeWith) T.T(2, 0, 1) T.UT() - assert.Nil(T.old) + assert.True(T.old.IsZero()) assert.Nil(T.transposeWith) } -var repeatTests = []struct { +type repeatTest struct { name string tensor *Dense ne bool // should assert tensor not equal @@ -297,7 +299,9 @@ var repeatTests = []struct { correct interface{} shape Shape err bool -}{ +} + +var repeatTests = []repeatTest{ {"Scalar Repeat on axis 0", New(FromScalar(true)), true, 0, []int{3}, []bool{true, true, true}, @@ -435,6 +439,49 @@ func TestDense_Repeat(t *testing.T) { } } +func TestDense_Repeat_Slow(t *testing.T) { + rt2 := make([]repeatTest, len(repeatTests)) + for i, rt := range repeatTests { + rt2[i] = repeatTest{ + name: rt.name, + ne: rt.ne, + axis: rt.axis, + repeats: rt.repeats, + correct: rt.correct, + shape: rt.shape, + err: rt.err, + tensor: rt.tensor.Clone().(*Dense), + } + } + for i := range rt2 { + maskLen := rt2[i].tensor.len() + mask := make([]bool, maskLen) + rt2[i].tensor.mask = mask + } + + assert := assert.New(t) + + for i, test := range rt2 { + T, err := test.tensor.Repeat(test.axis, test.repeats...) + if checkErr(t, test.err, err, "Repeat", i) { + continue + } + + var D DenseTensor + if D, err = getDenseTensor(T); err != nil { + t.Errorf("Expected Repeat to return a *Dense. got %v of %T instead", T, T) + continue + } + + if test.ne { + assert.NotEqual(test.tensor, D, test.name) + } + + assert.Equal(test.correct, D.Data(), test.name) + assert.Equal(test.shape, D.Shape(), test.name) + } +} + func TestDense_CopyTo(t *testing.T) { assert := assert.New(t) var T, T2 *Dense @@ -458,7 +505,7 @@ func TestDense_CopyTo(t *testing.T) { T = New(Of(Byte), WithShape(3, 3)) T2 = New(Of(Byte), WithShape(2, 2)) T3, _ = T.Slice(makeRS(0, 2), makeRS(0, 2)) // T[0:2, 0:2], shape == (2,2) - if err = T2.CopyTo(T3.(*Dense)); err != nil { + if err = T2.CopyTo(MustGetDense(T3)); err != nil { t.Log(err) // for now it's a not yet implemented error. 
TODO: FIX THIS } @@ -486,6 +533,11 @@ var denseSliceTests = []struct { correctStride []int correctData interface{} }{ + // scalar-equiv vector (issue 102) + {"a[0], a is scalar-equiv", []float64{2}, + Shape{1}, []Slice{ss(0)}, ScalarShape(), nil, 2.0}, + + // vector {"a[0]", []bool{true, true, false, false, false}, Shape{5}, []Slice{ss(0)}, ScalarShape(), nil, true}, {"a[0:2]", Range(Byte, 0, 5), Shape{5}, []Slice{makeRS(0, 2)}, Shape{2}, []int{1}, []byte{0, 1}}, @@ -493,16 +545,16 @@ var denseSliceTests = []struct { // colvec {"c[0]", Range(Int64, 0, 5), Shape{5, 1}, []Slice{ss(0)}, ScalarShape(), nil, int64(0)}, - {"c[0:2]", Range(Float32, 0, 5), Shape{5, 1}, []Slice{makeRS(0, 2)}, Shape{2, 1}, []int{1}, []float32{0, 1}}, - {"c[1:5:2]", Range(Float64, 0, 5), Shape{5, 1}, []Slice{makeRS(0, 5, 2)}, Shape{2, 1}, []int{2}, []float64{0, 1, 2, 3, 4}}, + {"c[0:2]", Range(Float32, 0, 5), Shape{5, 1}, []Slice{makeRS(0, 2)}, Shape{2, 1}, []int{1, 1}, []float32{0, 1}}, + {"c[1:5:2]", Range(Float64, 0, 5), Shape{5, 1}, []Slice{makeRS(0, 5, 2)}, Shape{2, 1}, []int{2, 1}, []float64{0, 1, 2, 3, 4}}, // // rowvec {"r[0]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{ss(0)}, Shape{1, 5}, []int{1}, []float64{0, 1, 2, 3, 4}}, {"r[0:2]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{makeRS(0, 2)}, Shape{1, 5}, []int{1}, []float64{0, 1, 2, 3, 4}}, {"r[0:5:2]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{makeRS(0, 5, 2)}, Shape{1, 5}, []int{1}, []float64{0, 1, 2, 3, 4}}, {"r[:, 0]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{nil, ss(0)}, ScalarShape(), nil, float64(0)}, - {"r[:, 0:2]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{nil, makeRS(0, 2)}, Shape{1, 2}, []int{1}, []float64{0, 1}}, - {"r[:, 1:5:2]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{nil, makeRS(1, 5, 2)}, Shape{1, 2}, []int{2}, []float64{1, 2, 3, 4}}, + {"r[:, 0:2]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{nil, makeRS(0, 2)}, Shape{1, 2}, []int{5, 1}, []float64{0, 1}}, + {"r[:, 1:5:2]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{nil, makeRS(1, 5, 2)}, Shape{1, 2}, []int{5, 2}, []float64{1, 2, 3, 4}}, // // matrix {"A[0]", Range(Float64, 0, 6), Shape{2, 3}, []Slice{ss(0)}, Shape{1, 3}, []int{1}, Range(Float64, 0, 3)}, @@ -513,6 +565,25 @@ var denseSliceTests = []struct { {"A[:, 0]", Range(Float64, 0, 20), Shape{4, 5}, []Slice{nil, ss(0)}, Shape{4, 1}, []int{5}, Range(Float64, 0, 16)}, {"A[:, 1:5]", Range(Float64, 0, 20), Shape{4, 5}, []Slice{nil, makeRS(1, 5)}, Shape{4, 4}, []int{5, 1}, Range(Float64, 1, 20)}, {"A[:, 1:5:2]", Range(Float64, 0, 20), Shape{4, 5}, []Slice{nil, makeRS(1, 5, 2)}, Shape{4, 2}, []int{5, 2}, Range(Float64, 1, 20)}, + + // 3tensor with leading and trailing 1s + + {"3T1[0]", Range(Float64, 0, 9), Shape{1, 9, 1}, []Slice{ss(0)}, Shape{9, 1}, []int{1, 1}, Range(Float64, 0, 9)}, + {"3T1[nil, 0:2]", Range(Float64, 0, 9), Shape{1, 9, 1}, []Slice{nil, makeRS(0, 2)}, Shape{1, 2, 1}, []int{9, 1, 1}, Range(Float64, 0, 2)}, + {"3T1[nil, 0:5:3]", Range(Float64, 0, 9), Shape{1, 9, 1}, []Slice{nil, makeRS(0, 5, 3)}, Shape{1, 2, 1}, []int{9, 3, 1}, Range(Float64, 0, 5)}, + {"3T1[nil, 1:5:3]", Range(Float64, 0, 9), Shape{1, 9, 1}, []Slice{nil, makeRS(1, 5, 3)}, Shape{1, 2, 1}, []int{9, 3, 1}, Range(Float64, 1, 5)}, + {"3T1[nil, 1:9:3]", Range(Float64, 0, 9), Shape{1, 9, 1}, []Slice{nil, makeRS(1, 9, 3)}, Shape{1, 3, 1}, []int{9, 3, 1}, Range(Float64, 1, 9)}, + + // 3tensor + {"3T[0]", Range(Float64, 0, 36), Shape{2, 9, 2}, []Slice{ss(0)}, Shape{9, 2}, []int{2, 1}, Range(Float64, 0, 18)}, + {"3T[1]", Range(Float64, 0, 36), Shape{2, 9, 2}, 
[]Slice{ss(1)}, Shape{9, 2}, []int{2, 1}, Range(Float64, 18, 36)}, + {"3T[1, 2]", Range(Float64, 0, 36), Shape{2, 9, 2}, []Slice{ss(1), ss(2)}, Shape{2}, []int{1}, Range(Float64, 22, 24)}, + {"3T[1, 2:4]", Range(Float64, 0, 36), Shape{2, 9, 2}, []Slice{ss(1), makeRS(2, 4)}, Shape{2, 2}, []int{2, 1}, Range(Float64, 22, 26)}, + {"3T[1, 2:8:2]", Range(Float64, 0, 36), Shape{2, 9, 2}, []Slice{ss(1), makeRS(2, 8, 2)}, Shape{3, 2}, []int{4, 1}, Range(Float64, 22, 34)}, + {"3T[1, 2:8:3]", Range(Float64, 0, 36), Shape{2, 9, 2}, []Slice{ss(1), makeRS(2, 8, 3)}, Shape{2, 2}, []int{6, 1}, Range(Float64, 22, 34)}, + {"3T[1, 2:9:2]", Range(Float64, 0, 126), Shape{2, 9, 7}, []Slice{ss(1), makeRS(2, 9, 2)}, Shape{4, 7}, []int{14, 1}, Range(Float64, 77, 126)}, + {"3T[1, 2:9:2, 1]", Range(Float64, 0, 126), Shape{2, 9, 7}, []Slice{ss(1), makeRS(2, 9, 2), ss(1)}, Shape{4}, []int{14}, Range(Float64, 78, 121)}, // should this be a colvec? + {"3T[1, 2:9:2, 1:4:2]", Range(Float64, 0, 126), Shape{2, 9, 7}, []Slice{ss(1), makeRS(2, 9, 2), makeRS(1, 4, 2)}, Shape{4, 2}, []int{14, 2}, Range(Float64, 78, 123)}, } func TestDense_Slice(t *testing.T) { @@ -540,10 +611,12 @@ func TestDense_Slice(t *testing.T) { assert.True(Shape{2}.Eq(V.Shape())) assert.Equal([]int{3}, V.Strides()) assert.Equal([]float32{0, 1, 2, 3}, V.Data()) - assert.Nil(V.(*Dense).old) + assert.True(MustGetDense(V).old.IsZero()) // slice a sliced + t.Logf("%v", V) V, err = V.Slice(makeRS(1, 2)) + t.Logf("%v", V) assert.True(ScalarShape().Eq(V.Shape())) assert.Equal(float32(3), V.Data()) @@ -566,7 +639,86 @@ func TestDense_Slice(t *testing.T) { if err == nil { t.Error("Expected a IndexError") } +} + +func TestDense_Narrow(t *testing.T) { + testCases := []struct { + x *Dense + dim, start, length int + expected *Dense + }{ + { + x: New( + WithShape(3), + WithBacking([]int{1, 2, 3}), + ), + dim: 0, + start: 1, + length: 1, + expected: New( + WithShape(), + WithBacking([]int{2}), + ), + }, + { + x: New( + WithShape(3, 3), + WithBacking([]int{1, 2, 3, 4, 5, 6, 7, 8, 9}), + ), + dim: 0, + start: 0, + length: 2, + expected: New( + WithShape(2, 3), + WithBacking([]int{1, 2, 3, 4, 5, 6}), + ), + }, + { + x: New( + WithShape(3, 3), + WithBacking([]int{1, 2, 3, 4, 5, 6, 7, 8, 9}), + ), + dim: 1, + start: 1, + length: 2, + expected: New( + WithShape(3, 2), + WithBacking([]int{2, 3, 5, 6, 8, 9}), + ), + }, + { + x: New( + WithShape(3, 3), + WithBacking([]int{1, 2, 3, 4, 5, 6, 7, 8, 9}), + ), + dim: 1, + start: 0, + length: 1, + expected: New( + WithShape(3), + WithBacking([]int{1, 4, 7}), + ), + }, + } + + for i, tC := range testCases { + t.Run(fmt.Sprintf("Example #%d narrow(%v,%d,%d,%v)", i+1, tC.x.Shape(), tC.dim, tC.start, tC.length), func(t *testing.T) { + c := assert.New(t) + // t.Logf("X:\n%v", tC.x) + + y, err := tC.x.Narrow(tC.dim, tC.start, tC.length) + c.NoError(err) + // t.Logf("y:\n%v", y) + yMat := y.Materialize() + c.Equal(tC.expected.Shape(), yMat.Shape()) + c.Equal(tC.expected.Data(), yMat.Data()) + + // err = y.Memset(1024) + // c.NoError(err) + // t.Logf("After Memset\nY: %v\nX:\n%v", y, tC.x) + }) + } } func TestDense_SliceInto(t *testing.T) { @@ -623,49 +775,61 @@ func TestDense_RollAxis(t *testing.T) { } var concatTests = []struct { - name string - dt Dtype - a interface{} - shape Shape - axis int + name string + dt dtype.Dtype + a interface{} + b interface{} + shape Shape + shapeB Shape + axis int correctShape Shape correctData interface{} }{ // Float64 - {"vector", Float64, nil, Shape{2}, 0, Shape{4}, []float64{0, 1, 0, 1}}, - {"matrix; axis 
0 ", Float64, nil, Shape{2, 2}, 0, Shape{4, 2}, []float64{0, 1, 2, 3, 0, 1, 2, 3}}, - {"matrix; axis 1 ", Float64, nil, Shape{2, 2}, 1, Shape{2, 4}, []float64{0, 1, 0, 1, 2, 3, 2, 3}}, + {"vector", Float64, nil, nil, Shape{2}, nil, 0, Shape{4}, []float64{0, 1, 0, 1}}, + {"matrix; axis 0 ", Float64, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []float64{0, 1, 2, 3, 0, 1, 2, 3}}, + {"matrix; axis 1 ", Float64, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []float64{0, 1, 0, 1, 2, 3, 2, 3}}, // Float32 - {"vector", Float32, nil, Shape{2}, 0, Shape{4}, []float32{0, 1, 0, 1}}, - {"matrix; axis 0 ", Float32, nil, Shape{2, 2}, 0, Shape{4, 2}, []float32{0, 1, 2, 3, 0, 1, 2, 3}}, - {"matrix; axis 1 ", Float32, nil, Shape{2, 2}, 1, Shape{2, 4}, []float32{0, 1, 0, 1, 2, 3, 2, 3}}, + {"vector", Float32, nil, nil, Shape{2}, nil, 0, Shape{4}, []float32{0, 1, 0, 1}}, + {"matrix; axis 0 ", Float32, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []float32{0, 1, 2, 3, 0, 1, 2, 3}}, + {"matrix; axis 1 ", Float32, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []float32{0, 1, 0, 1, 2, 3, 2, 3}}, // Int - {"vector", Int, nil, Shape{2}, 0, Shape{4}, []int{0, 1, 0, 1}}, - {"matrix; axis 0 ", Int, nil, Shape{2, 2}, 0, Shape{4, 2}, []int{0, 1, 2, 3, 0, 1, 2, 3}}, - {"matrix; axis 1 ", Int, nil, Shape{2, 2}, 1, Shape{2, 4}, []int{0, 1, 0, 1, 2, 3, 2, 3}}, + {"vector", Int, nil, nil, Shape{2}, nil, 0, Shape{4}, []int{0, 1, 0, 1}}, + {"matrix; axis 0 ", Int, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []int{0, 1, 2, 3, 0, 1, 2, 3}}, + {"matrix; axis 1 ", Int, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []int{0, 1, 0, 1, 2, 3, 2, 3}}, // Int64 - {"vector", Int64, nil, Shape{2}, 0, Shape{4}, []int64{0, 1, 0, 1}}, - {"matrix; axis 0 ", Int64, nil, Shape{2, 2}, 0, Shape{4, 2}, []int64{0, 1, 2, 3, 0, 1, 2, 3}}, - {"matrix; axis 1 ", Int64, nil, Shape{2, 2}, 1, Shape{2, 4}, []int64{0, 1, 0, 1, 2, 3, 2, 3}}, + {"vector", Int64, nil, nil, Shape{2}, nil, 0, Shape{4}, []int64{0, 1, 0, 1}}, + {"matrix; axis 0 ", Int64, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []int64{0, 1, 2, 3, 0, 1, 2, 3}}, + {"matrix; axis 1 ", Int64, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []int64{0, 1, 0, 1, 2, 3, 2, 3}}, // Int32 - {"vector", Int32, nil, Shape{2}, 0, Shape{4}, []int32{0, 1, 0, 1}}, - {"matrix; axis 0 ", Int32, nil, Shape{2, 2}, 0, Shape{4, 2}, []int32{0, 1, 2, 3, 0, 1, 2, 3}}, - {"matrix; axis 1 ", Int32, nil, Shape{2, 2}, 1, Shape{2, 4}, []int32{0, 1, 0, 1, 2, 3, 2, 3}}, + {"vector", Int32, nil, nil, Shape{2}, nil, 0, Shape{4}, []int32{0, 1, 0, 1}}, + {"matrix; axis 0 ", Int32, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []int32{0, 1, 2, 3, 0, 1, 2, 3}}, + {"matrix; axis 1 ", Int32, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []int32{0, 1, 0, 1, 2, 3, 2, 3}}, // Byte - {"vector", Byte, nil, Shape{2}, 0, Shape{4}, []byte{0, 1, 0, 1}}, - {"matrix; axis 0 ", Byte, nil, Shape{2, 2}, 0, Shape{4, 2}, []byte{0, 1, 2, 3, 0, 1, 2, 3}}, - {"matrix; axis 1 ", Byte, nil, Shape{2, 2}, 1, Shape{2, 4}, []byte{0, 1, 0, 1, 2, 3, 2, 3}}, + {"vector", Byte, nil, nil, Shape{2}, nil, 0, Shape{4}, []byte{0, 1, 0, 1}}, + {"matrix; axis 0 ", Byte, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []byte{0, 1, 2, 3, 0, 1, 2, 3}}, + {"matrix; axis 1 ", Byte, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []byte{0, 1, 0, 1, 2, 3, 2, 3}}, // Bool - {"vector", Bool, []bool{true, false}, Shape{2}, 0, Shape{4}, []bool{true, false, true, false}}, - {"matrix; axis 0 ", Bool, []bool{true, false, true, false}, Shape{2, 2}, 0, Shape{4, 2}, []bool{true, false, true, false, true, 
false, true, false}}, - {"matrix; axis 1 ", Bool, []bool{true, false, true, false}, Shape{2, 2}, 1, Shape{2, 4}, []bool{true, false, true, false, true, false, true, false}}, + {"vector", Bool, []bool{true, false}, nil, Shape{2}, nil, 0, Shape{4}, []bool{true, false, true, false}}, + {"matrix; axis 0 ", Bool, []bool{true, false, true, false}, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []bool{true, false, true, false, true, false, true, false}}, + {"matrix; axis 1 ", Bool, []bool{true, false, true, false}, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []bool{true, false, true, false, true, false, true, false}}, + + // gorgonia/gorgonia#218 related + {"matrix; axis 0", Float64, nil, nil, Shape{2, 2}, Shape{1, 2}, 0, Shape{3, 2}, []float64{0, 1, 2, 3, 0, 1}}, + {"matrix; axis 1", Float64, nil, nil, Shape{2, 2}, Shape{2, 1}, 1, Shape{2, 3}, []float64{0, 1, 0, 2, 3, 1}}, + {"colvec matrix, axis 0", Float64, nil, nil, Shape{2, 1}, Shape{1, 1}, 0, Shape{3, 1}, []float64{0, 1, 0}}, + {"rowvec matrix, axis 1", Float64, nil, nil, Shape{1, 2}, Shape{1, 1}, 1, Shape{1, 3}, []float64{0, 1, 0}}, + + {"3tensor; axis 0", Float64, nil, nil, Shape{2, 3, 2}, Shape{1, 3, 2}, 0, Shape{3, 3, 2}, []float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5}}, + {"3tensor; axis 2", Float64, nil, nil, Shape{2, 3, 2}, Shape{2, 3, 1}, 2, Shape{2, 3, 3}, []float64{0, 1, 0, 2, 3, 1, 4, 5, 2, 6, 7, 3, 8, 9, 4, 10, 11, 5}}, + {"3tensor; axis 1", Float64, nil, nil, Shape{2, 3, 2}, Shape{2, 1, 2}, 1, Shape{2, 4, 2}, []float64{0, 1, 2, 3, 4, 5, 0, 1, 6, 7, 8, 9, 10, 11, 2, 3}}, } func TestDense_Concat(t *testing.T) { @@ -676,17 +840,27 @@ func TestDense_Concat(t *testing.T) { if cts.a == nil { T0 = New(WithShape(cts.shape...), WithBacking(Range(cts.dt, 0, cts.shape.TotalSize()))) - T1 = New(WithShape(cts.shape...), WithBacking(Range(cts.dt, 0, cts.shape.TotalSize()))) } else { T0 = New(WithShape(cts.shape...), WithBacking(cts.a)) + } + + switch { + case cts.shapeB == nil && cts.a == nil: + T1 = New(WithShape(cts.shape...), WithBacking(Range(cts.dt, 0, cts.shape.TotalSize()))) + case cts.shapeB == nil && cts.a != nil: T1 = New(WithShape(cts.shape...), WithBacking(cloneArray(cts.a))) + case cts.shapeB != nil && cts.b == nil: + T1 = New(WithShape(cts.shapeB...), WithBacking(Range(cts.dt, 0, cts.shapeB.TotalSize()))) + case cts.shapeB != nil && cts.b != nil: + T1 = New(WithShape(cts.shapeB...), WithBacking(cts.b)) } T2, err := T0.Concat(cts.axis, T1) if err != nil { - t.Error(err) + t.Errorf("Test %v failed: %v", cts.name, err) continue } + assert.True(cts.correctShape.Eq(T2.Shape())) assert.Equal(cts.correctData, T2.Data()) } @@ -694,24 +868,31 @@ func TestDense_Concat(t *testing.T) { //Masked case for _, cts := range concatTests { - var T0, T1 *Dense if cts.a == nil { T0 = New(WithShape(cts.shape...), WithBacking(Range(cts.dt, 0, cts.shape.TotalSize()))) T0.MaskedEqual(castToDt(0.0, cts.dt)) - T1 = New(WithShape(cts.shape...), WithBacking(Range(cts.dt, 0, cts.shape.TotalSize()))) - T1.MaskedEqual(castToDt(0.0, cts.dt)) } else { T0 = New(WithShape(cts.shape...), WithBacking(cts.a)) T0.MaskedEqual(castToDt(0.0, cts.dt)) + } + + switch { + case cts.shapeB == nil && cts.a == nil: + T1 = New(WithShape(cts.shape...), WithBacking(Range(cts.dt, 0, cts.shape.TotalSize()))) + case cts.shapeB == nil && cts.a != nil: T1 = New(WithShape(cts.shape...), WithBacking(cloneArray(cts.a))) - T1.MaskedEqual(castToDt(0.0, cts.dt)) + case cts.shapeB != nil && cts.b == nil: + T1 = New(WithShape(cts.shapeB...), WithBacking(Range(cts.dt, 0, 
cts.shapeB.TotalSize()))) + case cts.shapeB != nil && cts.b != nil: + T1 = New(WithShape(cts.shapeB...), WithBacking(cts.b)) } + T1.MaskedEqual(castToDt(0.0, cts.dt)) T2, err := T0.Concat(cts.axis, T1) if err != nil { - t.Error(err) + t.Errorf("Test %v failed: %v", cts.name, err) continue } @@ -724,9 +905,36 @@ func TestDense_Concat(t *testing.T) { } } +func TestDense_Concat_sliced(t *testing.T) { + v := New( + WithShape(1, 5), + WithBacking([]float64{0, 1, 2, 3, 4}), + ) + cols := make([]Tensor, v.Shape().TotalSize()) + for i := 0; i < v.Shape().TotalSize(); i++ { + sliced, err := v.Slice(nil, ss(i)) + if err != nil { + t.Fatalf("Failed to slice %d. Error: %v", i, err) + } + if err = sliced.Reshape(sliced.Shape().TotalSize(), 1); err != nil { + t.Fatalf("Failed to reshape %d. Error %v", i, err) + } + cols[i] = sliced + } + result, err := Concat(1, cols[0], cols[1:]...) + if err != nil { + t.Error(err) + } + assert.Equal(t, v.Data(), result.Data()) + if v.Uintptr() == result.Uintptr() { + t.Error("They should not share the same backing data!") + } + +} + var simpleStackTests = []struct { name string - dt Dtype + dt dtype.Dtype shape Shape axis int stackCount int @@ -777,7 +985,7 @@ var simpleStackTests = []struct { var viewStackTests = []struct { name string - dt Dtype + dt dtype.Dtype shape Shape transform []int slices []Slice @@ -834,12 +1042,12 @@ func TestDense_Stack(t *testing.T) { T := New(WithShape(sts.shape...), WithBacking(Range(sts.dt, 0, sts.shape.TotalSize()))) switch { case sts.slices != nil && sts.transform == nil: - var sliced Tensor + var sliced View if sliced, err = T.Slice(sts.slices...); err != nil { t.Error(err) continue } - T = sliced.(*Dense) + T = MustGetDense(sliced) case sts.transform != nil && sts.slices == nil: T.T(sts.transform...) } @@ -850,12 +1058,12 @@ func TestDense_Stack(t *testing.T) { T1 := New(WithShape(sts.shape...), WithBacking(Range(sts.dt, offset, sts.shape.TotalSize()+offset))) switch { case sts.slices != nil && sts.transform == nil: - var sliced Tensor + var sliced View if sliced, err = T1.Slice(sts.slices...); err != nil { t.Error(err) continue } - T1 = sliced.(*Dense) + T1 = MustGetDense(sliced) case sts.transform != nil && sts.slices == nil: T1.T(sts.transform...) } @@ -901,12 +1109,12 @@ func TestDense_Stack(t *testing.T) { T := New(WithShape(sts.shape...), WithBacking(Range(sts.dt, 0, sts.shape.TotalSize()))) switch { case sts.slices != nil && sts.transform == nil: - var sliced Tensor + var sliced View if sliced, err = T.Slice(sts.slices...); err != nil { t.Error(err) continue } - T = sliced.(*Dense) + T = MustGetDense(sliced) case sts.transform != nil && sts.slices == nil: T.T(sts.transform...) } @@ -918,12 +1126,12 @@ func TestDense_Stack(t *testing.T) { T1.MaskedInside(castToDt(102.0, sts.dt), castToDt(225.0, sts.dt)) switch { case sts.slices != nil && sts.transform == nil: - var sliced Tensor + var sliced View if sliced, err = T1.Slice(sts.slices...); err != nil { t.Error(err) continue } - T1 = sliced.(*Dense) + T1 = MustGetDense(sliced) case sts.transform != nil && sts.slices == nil: T1.T(sts.transform...) } @@ -951,12 +1159,12 @@ func TestDense_Stack(t *testing.T) { var stacked []*Dense for i := 0; i < 1; i++ { T1 := New(WithShape(2, 2), WithBacking([]string{"blah1", "blah2", "blah3", "blah4"})) - var sliced Tensor + var sliced View if sliced, err = T1.Slice(nil, nil); err != nil { t.Error(err) break } - T1 = sliced.(*Dense) + T1 = MustGetDense(sliced) stacked = append(stacked, T1) } T2, err := T.Stack(0, stacked...) 
diff --git a/dense_norms.go b/dense_norms.go index ad75c0f..63d460a 100644 --- a/dense_norms.go +++ b/dense_norms.go @@ -94,8 +94,8 @@ func (t *Dense) Norm(ord NormOrder, axes ...int) (retVal *Dense, err error) { if len(axes) == 0 { if ord.IsUnordered() || (ord.IsFrobenius() && dims == 2) || (ord == Norm(2) && dims == 1) { backup := t.AP - ap := BorrowAP(1) - defer ReturnAP(ap) + ap := makeAP(1) + defer ap.zero() ap.unlock() ap.SetShape(t.Size()) diff --git a/dense_norms_test.go b/dense_norms_test.go index 316b32a..69879ee 100644 --- a/dense_norms_test.go +++ b/dense_norms_test.go @@ -120,12 +120,13 @@ func TestTensor_Norm(t *testing.T) { t.Error(err) } } + } func TestTensor_Norm_Axis(t *testing.T) { assert := assert.New(t) var T, s, expected, retVal *Dense - var sliced Tensor + var sliced View var err error var backing []float64 var ords []NormOrder @@ -149,7 +150,7 @@ func TestTensor_Norm_Axis(t *testing.T) { var expecteds []*Dense for k := 0; k < T.Shape()[1]; k++ { sliced, _ = T.Slice(nil, ss(k)) - s = sliced.(View).Materialize().(*Dense) + s = sliced.Materialize().(*Dense) expected, _ = s.Norm(ord) expecteds = append(expecteds, expected) } @@ -162,8 +163,8 @@ func TestTensor_Norm_Axis(t *testing.T) { assert.Equal(len(expecteds), retVal.Shape()[0]) for i, e := range expecteds { sliced, _ = retVal.Slice(ss(i)) - sliced = sliced.(View).Materialize() - if !allClose(e.Data(), sliced.Data()) { + mat := sliced.Materialize() + if !allClose(e.Data(), mat.Data()) { t.Errorf("Axis = 0; Ord = %v; Expected %v. Got %v instead. ret %v, i: %d", ord, e.Data(), sliced.Data(), retVal, i) } } @@ -173,7 +174,7 @@ func TestTensor_Norm_Axis(t *testing.T) { expecteds = expecteds[:0] for k := 0; k < T.Shape()[0]; k++ { sliced, _ = T.Slice(ss(k)) - s = sliced.(*Dense) + s = MustGetDense(sliced) expected, _ = s.Norm(ord) expecteds = append(expecteds, expected) } @@ -185,8 +186,8 @@ func TestTensor_Norm_Axis(t *testing.T) { assert.Equal(len(expecteds), retVal.Shape()[0]) for i, e := range expecteds { sliced, _ = retVal.Slice(ss(i)) - sliced = sliced.(View).Materialize().(*Dense) - if !allClose(e.Data(), sliced.Data()) { + mat := sliced.Materialize() + if !allClose(e.Data(), mat.Data()) { t.Errorf("Axis = 1; Ord = %v; Expected %v. 
Got %v instead", ord, e.Data(), sliced.Data()) } } @@ -249,9 +250,8 @@ func TestTensor_Norm_Axis(t *testing.T) { if rowAxis > colAxis { sliced.T() } - sliced = sliced.(View).Materialize().(*Dense) - s = sliced.(*Dense) - expected, _ = s.Norm(ord) + mat := sliced.Materialize().(*Dense) + expected, _ = mat.Norm(ord) expecteds = append(expecteds, expected) } diff --git a/dense_reduction_methods.go b/dense_reduction_methods.go index cb744b5..28058a2 100644 --- a/dense_reduction_methods.go +++ b/dense_reduction_methods.go @@ -3,37 +3,65 @@ package tensor import "github.com/pkg/errors" func (t *Dense) Sum(along ...int) (retVal *Dense, err error) { - var e Engine = t.e + e := t.Engine() + ctx := ctxFromEngine(e) if sumer, ok := e.(Sumer); ok { var ret Tensor - if ret, err = sumer.Sum(t, along...); err != nil { + if ret, err = sumer.Sum(ctx, t, along...); err != nil { return } - return ret.(*Dense), nil + if retVal, err = assertDense(ret); err != nil { + return nil, errors.Wrapf(err, opFail, "Sum") + } + return } return nil, errors.Errorf("Engine does not support Sum") } +func (t *Dense) Prod(along ...int) (retVal *Dense, err error) { + e := t.Engine() + ctx := ctxFromEngine(e) + if sumer, ok := e.(Proder); ok { + var ret Tensor + if ret, err = sumer.Prod(ctx, t, along...); err != nil { + return + } + if retVal, err = assertDense(ret); err != nil { + return nil, errors.Wrapf(err, opFail, "Prod") + } + return + } + return nil, errors.Errorf("Engine does not support Prod") +} + func (t *Dense) Max(along ...int) (retVal *Dense, err error) { - var e Engine = t.e + e := t.Engine() + ctx := ctxFromEngine(e) if maxer, ok := e.(Maxer); ok { var ret Tensor - if ret, err = maxer.Max(t, along...); err != nil { + if ret, err = maxer.Max(ctx, t, along...); err != nil { return } - return ret.(*Dense), nil + if retVal, err = assertDense(ret); err != nil { + return nil, errors.Wrapf(err, opFail, "Max") + } + return } return nil, errors.Errorf("Engine does not support Max") } func (t *Dense) Min(along ...int) (retVal *Dense, err error) { - var e Engine = t.e + e := t.Engine() + ctx := ctxFromEngine(e) if miner, ok := e.(Miner); ok { var ret Tensor - if ret, err = miner.Min(t, along...); err != nil { + if ret, err = miner.Min(ctx, t, along...); err != nil { return } - return ret.(*Dense), nil + if retVal, err = assertDense(ret); err != nil { + return nil, errors.Wrapf(err, opFail, "Min") + } + return } return nil, errors.Errorf("Engine does not support Min") } diff --git a/dense_reduction_test.go b/dense_reduction_test.go index ffe673d..e4ef5ec 100644 --- a/dense_reduction_test.go +++ b/dense_reduction_test.go @@ -1,16 +1,17 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( "testing" "github.com/stretchr/testify/assert" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/execution" ) +// Code generated by genlib2. DO NOT EDIT. 
+ var denseReductionTests = []struct { - of Dtype + of dtype.Dtype fn interface{} def interface{} axis int @@ -116,7 +117,7 @@ func TestDense_Reduce(t *testing.T) { var sumTests = []struct { name string - of Dtype + of dtype.Dtype shape Shape along []int @@ -129,84 +130,126 @@ var sumTests = []struct { {"A.Sum(0,1) for int", Int, Shape{2, 3}, []int{0, 1}, ScalarShape(), int(15)}, {"A.Sum(1,0) for int", Int, Shape{2, 3}, []int{1, 0}, ScalarShape(), int(15)}, {"3T.Sum(1,2) for int", Int, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int{66, 210}}, + {"4T.Sum() for int", Int, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int(120)}, + {"4T.Sum(1,3) for int", Int, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for int", Int, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int{44, 76}}, {"common case: T.Sum() for int8", Int8, Shape{2, 3}, []int{}, ScalarShape(), int8(15)}, {"A.Sum(0) for int8", Int8, Shape{2, 3}, []int{0}, Shape{3}, []int8{3, 5, 7}}, {"A.Sum(1) for int8", Int8, Shape{2, 3}, []int{1}, Shape{2}, []int8{3, 12}}, {"A.Sum(0,1) for int8", Int8, Shape{2, 3}, []int{0, 1}, ScalarShape(), int8(15)}, {"A.Sum(1,0) for int8", Int8, Shape{2, 3}, []int{1, 0}, ScalarShape(), int8(15)}, {"3T.Sum(1,2) for int8", Int8, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int8{66, -46}}, + {"4T.Sum() for int8", Int8, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int8(120)}, + {"4T.Sum(1,3) for int8", Int8, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int8{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for int8", Int8, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int8{44, 76}}, {"common case: T.Sum() for int16", Int16, Shape{2, 3}, []int{}, ScalarShape(), int16(15)}, {"A.Sum(0) for int16", Int16, Shape{2, 3}, []int{0}, Shape{3}, []int16{3, 5, 7}}, {"A.Sum(1) for int16", Int16, Shape{2, 3}, []int{1}, Shape{2}, []int16{3, 12}}, {"A.Sum(0,1) for int16", Int16, Shape{2, 3}, []int{0, 1}, ScalarShape(), int16(15)}, {"A.Sum(1,0) for int16", Int16, Shape{2, 3}, []int{1, 0}, ScalarShape(), int16(15)}, {"3T.Sum(1,2) for int16", Int16, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int16{66, 210}}, + {"4T.Sum() for int16", Int16, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int16(120)}, + {"4T.Sum(1,3) for int16", Int16, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int16{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for int16", Int16, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int16{44, 76}}, {"common case: T.Sum() for int32", Int32, Shape{2, 3}, []int{}, ScalarShape(), int32(15)}, {"A.Sum(0) for int32", Int32, Shape{2, 3}, []int{0}, Shape{3}, []int32{3, 5, 7}}, {"A.Sum(1) for int32", Int32, Shape{2, 3}, []int{1}, Shape{2}, []int32{3, 12}}, {"A.Sum(0,1) for int32", Int32, Shape{2, 3}, []int{0, 1}, ScalarShape(), int32(15)}, {"A.Sum(1,0) for int32", Int32, Shape{2, 3}, []int{1, 0}, ScalarShape(), int32(15)}, {"3T.Sum(1,2) for int32", Int32, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int32{66, 210}}, + {"4T.Sum() for int32", Int32, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int32(120)}, + {"4T.Sum(1,3) for int32", Int32, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int32{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for int32", Int32, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int32{44, 76}}, {"common case: T.Sum() for int64", Int64, Shape{2, 3}, []int{}, ScalarShape(), int64(15)}, {"A.Sum(0) for int64", Int64, Shape{2, 3}, []int{0}, Shape{3}, []int64{3, 5, 7}}, {"A.Sum(1) for int64", Int64, Shape{2, 3}, []int{1}, Shape{2}, []int64{3, 12}}, {"A.Sum(0,1) for int64", Int64, Shape{2, 3}, []int{0, 1}, ScalarShape(), 
int64(15)}, {"A.Sum(1,0) for int64", Int64, Shape{2, 3}, []int{1, 0}, ScalarShape(), int64(15)}, {"3T.Sum(1,2) for int64", Int64, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int64{66, 210}}, + {"4T.Sum() for int64", Int64, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int64(120)}, + {"4T.Sum(1,3) for int64", Int64, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int64{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for int64", Int64, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int64{44, 76}}, {"common case: T.Sum() for uint", Uint, Shape{2, 3}, []int{}, ScalarShape(), uint(15)}, {"A.Sum(0) for uint", Uint, Shape{2, 3}, []int{0}, Shape{3}, []uint{3, 5, 7}}, {"A.Sum(1) for uint", Uint, Shape{2, 3}, []int{1}, Shape{2}, []uint{3, 12}}, {"A.Sum(0,1) for uint", Uint, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint(15)}, {"A.Sum(1,0) for uint", Uint, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint(15)}, {"3T.Sum(1,2) for uint", Uint, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint{66, 210}}, + {"4T.Sum() for uint", Uint, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint(120)}, + {"4T.Sum(1,3) for uint", Uint, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for uint", Uint, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint{44, 76}}, {"common case: T.Sum() for uint8", Uint8, Shape{2, 3}, []int{}, ScalarShape(), uint8(15)}, {"A.Sum(0) for uint8", Uint8, Shape{2, 3}, []int{0}, Shape{3}, []uint8{3, 5, 7}}, {"A.Sum(1) for uint8", Uint8, Shape{2, 3}, []int{1}, Shape{2}, []uint8{3, 12}}, {"A.Sum(0,1) for uint8", Uint8, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint8(15)}, {"A.Sum(1,0) for uint8", Uint8, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint8(15)}, {"3T.Sum(1,2) for uint8", Uint8, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint8{66, 210}}, + {"4T.Sum() for uint8", Uint8, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint8(120)}, + {"4T.Sum(1,3) for uint8", Uint8, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint8{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for uint8", Uint8, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint8{44, 76}}, {"common case: T.Sum() for uint16", Uint16, Shape{2, 3}, []int{}, ScalarShape(), uint16(15)}, {"A.Sum(0) for uint16", Uint16, Shape{2, 3}, []int{0}, Shape{3}, []uint16{3, 5, 7}}, {"A.Sum(1) for uint16", Uint16, Shape{2, 3}, []int{1}, Shape{2}, []uint16{3, 12}}, {"A.Sum(0,1) for uint16", Uint16, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint16(15)}, {"A.Sum(1,0) for uint16", Uint16, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint16(15)}, {"3T.Sum(1,2) for uint16", Uint16, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint16{66, 210}}, + {"4T.Sum() for uint16", Uint16, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint16(120)}, + {"4T.Sum(1,3) for uint16", Uint16, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint16{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for uint16", Uint16, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint16{44, 76}}, {"common case: T.Sum() for uint32", Uint32, Shape{2, 3}, []int{}, ScalarShape(), uint32(15)}, {"A.Sum(0) for uint32", Uint32, Shape{2, 3}, []int{0}, Shape{3}, []uint32{3, 5, 7}}, {"A.Sum(1) for uint32", Uint32, Shape{2, 3}, []int{1}, Shape{2}, []uint32{3, 12}}, {"A.Sum(0,1) for uint32", Uint32, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint32(15)}, {"A.Sum(1,0) for uint32", Uint32, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint32(15)}, {"3T.Sum(1,2) for uint32", Uint32, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint32{66, 210}}, + {"4T.Sum() for uint32", Uint32, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint32(120)}, + {"4T.Sum(1,3) 
for uint32", Uint32, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint32{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for uint32", Uint32, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint32{44, 76}}, {"common case: T.Sum() for uint64", Uint64, Shape{2, 3}, []int{}, ScalarShape(), uint64(15)}, {"A.Sum(0) for uint64", Uint64, Shape{2, 3}, []int{0}, Shape{3}, []uint64{3, 5, 7}}, {"A.Sum(1) for uint64", Uint64, Shape{2, 3}, []int{1}, Shape{2}, []uint64{3, 12}}, {"A.Sum(0,1) for uint64", Uint64, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint64(15)}, {"A.Sum(1,0) for uint64", Uint64, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint64(15)}, {"3T.Sum(1,2) for uint64", Uint64, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint64{66, 210}}, + {"4T.Sum() for uint64", Uint64, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint64(120)}, + {"4T.Sum(1,3) for uint64", Uint64, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint64{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for uint64", Uint64, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint64{44, 76}}, {"common case: T.Sum() for float32", Float32, Shape{2, 3}, []int{}, ScalarShape(), float32(15)}, {"A.Sum(0) for float32", Float32, Shape{2, 3}, []int{0}, Shape{3}, []float32{3, 5, 7}}, {"A.Sum(1) for float32", Float32, Shape{2, 3}, []int{1}, Shape{2}, []float32{3, 12}}, {"A.Sum(0,1) for float32", Float32, Shape{2, 3}, []int{0, 1}, ScalarShape(), float32(15)}, {"A.Sum(1,0) for float32", Float32, Shape{2, 3}, []int{1, 0}, ScalarShape(), float32(15)}, {"3T.Sum(1,2) for float32", Float32, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []float32{66, 210}}, + {"4T.Sum() for float32", Float32, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), float32(120)}, + {"4T.Sum(1,3) for float32", Float32, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []float32{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for float32", Float32, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []float32{44, 76}}, {"common case: T.Sum() for float64", Float64, Shape{2, 3}, []int{}, ScalarShape(), float64(15)}, {"A.Sum(0) for float64", Float64, Shape{2, 3}, []int{0}, Shape{3}, []float64{3, 5, 7}}, {"A.Sum(1) for float64", Float64, Shape{2, 3}, []int{1}, Shape{2}, []float64{3, 12}}, {"A.Sum(0,1) for float64", Float64, Shape{2, 3}, []int{0, 1}, ScalarShape(), float64(15)}, {"A.Sum(1,0) for float64", Float64, Shape{2, 3}, []int{1, 0}, ScalarShape(), float64(15)}, {"3T.Sum(1,2) for float64", Float64, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []float64{66, 210}}, + {"4T.Sum() for float64", Float64, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), float64(120)}, + {"4T.Sum(1,3) for float64", Float64, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []float64{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for float64", Float64, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []float64{44, 76}}, {"common case: T.Sum() for complex64", Complex64, Shape{2, 3}, []int{}, ScalarShape(), complex64(15)}, {"A.Sum(0) for complex64", Complex64, Shape{2, 3}, []int{0}, Shape{3}, []complex64{3, 5, 7}}, {"A.Sum(1) for complex64", Complex64, Shape{2, 3}, []int{1}, Shape{2}, []complex64{3, 12}}, {"A.Sum(0,1) for complex64", Complex64, Shape{2, 3}, []int{0, 1}, ScalarShape(), complex64(15)}, {"A.Sum(1,0) for complex64", Complex64, Shape{2, 3}, []int{1, 0}, ScalarShape(), complex64(15)}, {"3T.Sum(1,2) for complex64", Complex64, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []complex64{66, 210}}, + {"4T.Sum() for complex64", Complex64, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), complex64(120)}, + {"4T.Sum(1,3) for complex64", Complex64, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []complex64{10, 
18, 42, 50}}, + {"4T.Sum(0, 2, 3) for complex64", Complex64, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []complex64{44, 76}}, {"common case: T.Sum() for complex128", Complex128, Shape{2, 3}, []int{}, ScalarShape(), complex128(15)}, {"A.Sum(0) for complex128", Complex128, Shape{2, 3}, []int{0}, Shape{3}, []complex128{3, 5, 7}}, {"A.Sum(1) for complex128", Complex128, Shape{2, 3}, []int{1}, Shape{2}, []complex128{3, 12}}, {"A.Sum(0,1) for complex128", Complex128, Shape{2, 3}, []int{0, 1}, ScalarShape(), complex128(15)}, {"A.Sum(1,0) for complex128", Complex128, Shape{2, 3}, []int{1, 0}, ScalarShape(), complex128(15)}, {"3T.Sum(1,2) for complex128", Complex128, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []complex128{66, 210}}, + {"4T.Sum() for complex128", Complex128, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), complex128(120)}, + {"4T.Sum(1,3) for complex128", Complex128, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []complex128{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for complex128", Complex128, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []complex128{44, 76}}, } func TestDense_Sum(t *testing.T) { @@ -231,7 +274,7 @@ func TestDense_Sum(t *testing.T) { var maxTests = []struct { name string - of Dtype + of dtype.Dtype shape Shape along []int @@ -244,72 +287,108 @@ var maxTests = []struct { {"A.Max(0,1)", Int, Shape{2, 3}, []int{0, 1}, ScalarShape(), int(5)}, {"A.Max(1,0)", Int, Shape{2, 3}, []int{1, 0}, ScalarShape(), int(5)}, {"3T.Max(1,2)", Int, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int{11, 23}}, + {"4T.Max()", Int, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int(15)}, + {"4T.Max(1,3)", Int, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", Int, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int{11, 15}}, {"common case: T.Max() for int8", Int8, Shape{2, 3}, []int{}, ScalarShape(), int8(5)}, {"A.Max(0)", Int8, Shape{2, 3}, []int{0}, Shape{3}, []int8{3, 4, 5}}, {"A.Max(1)", Int8, Shape{2, 3}, []int{1}, Shape{2}, []int8{2, 5}}, {"A.Max(0,1)", Int8, Shape{2, 3}, []int{0, 1}, ScalarShape(), int8(5)}, {"A.Max(1,0)", Int8, Shape{2, 3}, []int{1, 0}, ScalarShape(), int8(5)}, {"3T.Max(1,2)", Int8, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int8{11, 23}}, + {"4T.Max()", Int8, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int8(15)}, + {"4T.Max(1,3)", Int8, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int8{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", Int8, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int8{11, 15}}, {"common case: T.Max() for int16", Int16, Shape{2, 3}, []int{}, ScalarShape(), int16(5)}, {"A.Max(0)", Int16, Shape{2, 3}, []int{0}, Shape{3}, []int16{3, 4, 5}}, {"A.Max(1)", Int16, Shape{2, 3}, []int{1}, Shape{2}, []int16{2, 5}}, {"A.Max(0,1)", Int16, Shape{2, 3}, []int{0, 1}, ScalarShape(), int16(5)}, {"A.Max(1,0)", Int16, Shape{2, 3}, []int{1, 0}, ScalarShape(), int16(5)}, {"3T.Max(1,2)", Int16, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int16{11, 23}}, + {"4T.Max()", Int16, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int16(15)}, + {"4T.Max(1,3)", Int16, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int16{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", Int16, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int16{11, 15}}, {"common case: T.Max() for int32", Int32, Shape{2, 3}, []int{}, ScalarShape(), int32(5)}, {"A.Max(0)", Int32, Shape{2, 3}, []int{0}, Shape{3}, []int32{3, 4, 5}}, {"A.Max(1)", Int32, Shape{2, 3}, []int{1}, Shape{2}, []int32{2, 5}}, {"A.Max(0,1)", Int32, Shape{2, 3}, []int{0, 1}, ScalarShape(), int32(5)}, {"A.Max(1,0)", Int32, Shape{2, 3}, []int{1, 0}, 
ScalarShape(), int32(5)}, {"3T.Max(1,2)", Int32, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int32{11, 23}}, + {"4T.Max()", Int32, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int32(15)}, + {"4T.Max(1,3)", Int32, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int32{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", Int32, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int32{11, 15}}, {"common case: T.Max() for int64", Int64, Shape{2, 3}, []int{}, ScalarShape(), int64(5)}, {"A.Max(0)", Int64, Shape{2, 3}, []int{0}, Shape{3}, []int64{3, 4, 5}}, {"A.Max(1)", Int64, Shape{2, 3}, []int{1}, Shape{2}, []int64{2, 5}}, {"A.Max(0,1)", Int64, Shape{2, 3}, []int{0, 1}, ScalarShape(), int64(5)}, {"A.Max(1,0)", Int64, Shape{2, 3}, []int{1, 0}, ScalarShape(), int64(5)}, {"3T.Max(1,2)", Int64, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int64{11, 23}}, + {"4T.Max()", Int64, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int64(15)}, + {"4T.Max(1,3)", Int64, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int64{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", Int64, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int64{11, 15}}, {"common case: T.Max() for uint", Uint, Shape{2, 3}, []int{}, ScalarShape(), uint(5)}, {"A.Max(0)", Uint, Shape{2, 3}, []int{0}, Shape{3}, []uint{3, 4, 5}}, {"A.Max(1)", Uint, Shape{2, 3}, []int{1}, Shape{2}, []uint{2, 5}}, {"A.Max(0,1)", Uint, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint(5)}, {"A.Max(1,0)", Uint, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint(5)}, {"3T.Max(1,2)", Uint, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint{11, 23}}, + {"4T.Max()", Uint, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint(15)}, + {"4T.Max(1,3)", Uint, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", Uint, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint{11, 15}}, {"common case: T.Max() for uint8", Uint8, Shape{2, 3}, []int{}, ScalarShape(), uint8(5)}, {"A.Max(0)", Uint8, Shape{2, 3}, []int{0}, Shape{3}, []uint8{3, 4, 5}}, {"A.Max(1)", Uint8, Shape{2, 3}, []int{1}, Shape{2}, []uint8{2, 5}}, {"A.Max(0,1)", Uint8, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint8(5)}, {"A.Max(1,0)", Uint8, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint8(5)}, {"3T.Max(1,2)", Uint8, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint8{11, 23}}, + {"4T.Max()", Uint8, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint8(15)}, + {"4T.Max(1,3)", Uint8, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint8{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", Uint8, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint8{11, 15}}, {"common case: T.Max() for uint16", Uint16, Shape{2, 3}, []int{}, ScalarShape(), uint16(5)}, {"A.Max(0)", Uint16, Shape{2, 3}, []int{0}, Shape{3}, []uint16{3, 4, 5}}, {"A.Max(1)", Uint16, Shape{2, 3}, []int{1}, Shape{2}, []uint16{2, 5}}, {"A.Max(0,1)", Uint16, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint16(5)}, {"A.Max(1,0)", Uint16, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint16(5)}, {"3T.Max(1,2)", Uint16, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint16{11, 23}}, + {"4T.Max()", Uint16, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint16(15)}, + {"4T.Max(1,3)", Uint16, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint16{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", Uint16, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint16{11, 15}}, {"common case: T.Max() for uint32", Uint32, Shape{2, 3}, []int{}, ScalarShape(), uint32(5)}, {"A.Max(0)", Uint32, Shape{2, 3}, []int{0}, Shape{3}, []uint32{3, 4, 5}}, {"A.Max(1)", Uint32, Shape{2, 3}, []int{1}, Shape{2}, []uint32{2, 5}}, {"A.Max(0,1)", Uint32, Shape{2, 3}, []int{0, 1}, 
ScalarShape(), uint32(5)}, {"A.Max(1,0)", Uint32, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint32(5)}, {"3T.Max(1,2)", Uint32, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint32{11, 23}}, + {"4T.Max()", Uint32, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint32(15)}, + {"4T.Max(1,3)", Uint32, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint32{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", Uint32, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint32{11, 15}}, {"common case: T.Max() for uint64", Uint64, Shape{2, 3}, []int{}, ScalarShape(), uint64(5)}, {"A.Max(0)", Uint64, Shape{2, 3}, []int{0}, Shape{3}, []uint64{3, 4, 5}}, {"A.Max(1)", Uint64, Shape{2, 3}, []int{1}, Shape{2}, []uint64{2, 5}}, {"A.Max(0,1)", Uint64, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint64(5)}, {"A.Max(1,0)", Uint64, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint64(5)}, {"3T.Max(1,2)", Uint64, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint64{11, 23}}, + {"4T.Max()", Uint64, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint64(15)}, + {"4T.Max(1,3)", Uint64, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint64{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", Uint64, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint64{11, 15}}, {"common case: T.Max() for float32", Float32, Shape{2, 3}, []int{}, ScalarShape(), float32(5)}, {"A.Max(0)", Float32, Shape{2, 3}, []int{0}, Shape{3}, []float32{3, 4, 5}}, {"A.Max(1)", Float32, Shape{2, 3}, []int{1}, Shape{2}, []float32{2, 5}}, {"A.Max(0,1)", Float32, Shape{2, 3}, []int{0, 1}, ScalarShape(), float32(5)}, {"A.Max(1,0)", Float32, Shape{2, 3}, []int{1, 0}, ScalarShape(), float32(5)}, {"3T.Max(1,2)", Float32, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []float32{11, 23}}, + {"4T.Max()", Float32, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), float32(15)}, + {"4T.Max(1,3)", Float32, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []float32{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", Float32, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []float32{11, 15}}, {"common case: T.Max() for float64", Float64, Shape{2, 3}, []int{}, ScalarShape(), float64(5)}, {"A.Max(0)", Float64, Shape{2, 3}, []int{0}, Shape{3}, []float64{3, 4, 5}}, {"A.Max(1)", Float64, Shape{2, 3}, []int{1}, Shape{2}, []float64{2, 5}}, {"A.Max(0,1)", Float64, Shape{2, 3}, []int{0, 1}, ScalarShape(), float64(5)}, {"A.Max(1,0)", Float64, Shape{2, 3}, []int{1, 0}, ScalarShape(), float64(5)}, {"3T.Max(1,2)", Float64, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []float64{11, 23}}, + {"4T.Max()", Float64, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), float64(15)}, + {"4T.Max(1,3)", Float64, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []float64{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", Float64, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []float64{11, 15}}, } func TestDense_Max(t *testing.T) { @@ -333,7 +412,7 @@ func TestDense_Max(t *testing.T) { var minTests = []struct { name string - of Dtype + of dtype.Dtype shape Shape along []int @@ -346,72 +425,108 @@ var minTests = []struct { {"A.Min(0,1)", Int, Shape{2, 3}, []int{0, 1}, ScalarShape(), int(0)}, {"A.Min(1,0)", Int, Shape{2, 3}, []int{1, 0}, ScalarShape(), int(0)}, {"3T.Min(1,2)", Int, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int{0, 12}}, + {"4T.Min()", Int, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int(0)}, + {"4T.Min(1,3)", Int, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", Int, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int{0, 4}}, {"common case: T.Min() for int8", Int8, Shape{2, 3}, []int{}, ScalarShape(), int8(0)}, {"A.Min(0)", Int8, Shape{2, 3}, []int{0}, Shape{3}, 
[]int8{0, 1, 2}}, {"A.Min(1)", Int8, Shape{2, 3}, []int{1}, Shape{2}, []int8{0, 3}}, {"A.Min(0,1)", Int8, Shape{2, 3}, []int{0, 1}, ScalarShape(), int8(0)}, {"A.Min(1,0)", Int8, Shape{2, 3}, []int{1, 0}, ScalarShape(), int8(0)}, {"3T.Min(1,2)", Int8, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int8{0, 12}}, + {"4T.Min()", Int8, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int8(0)}, + {"4T.Min(1,3)", Int8, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int8{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", Int8, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int8{0, 4}}, {"common case: T.Min() for int16", Int16, Shape{2, 3}, []int{}, ScalarShape(), int16(0)}, {"A.Min(0)", Int16, Shape{2, 3}, []int{0}, Shape{3}, []int16{0, 1, 2}}, {"A.Min(1)", Int16, Shape{2, 3}, []int{1}, Shape{2}, []int16{0, 3}}, {"A.Min(0,1)", Int16, Shape{2, 3}, []int{0, 1}, ScalarShape(), int16(0)}, {"A.Min(1,0)", Int16, Shape{2, 3}, []int{1, 0}, ScalarShape(), int16(0)}, {"3T.Min(1,2)", Int16, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int16{0, 12}}, + {"4T.Min()", Int16, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int16(0)}, + {"4T.Min(1,3)", Int16, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int16{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", Int16, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int16{0, 4}}, {"common case: T.Min() for int32", Int32, Shape{2, 3}, []int{}, ScalarShape(), int32(0)}, {"A.Min(0)", Int32, Shape{2, 3}, []int{0}, Shape{3}, []int32{0, 1, 2}}, {"A.Min(1)", Int32, Shape{2, 3}, []int{1}, Shape{2}, []int32{0, 3}}, {"A.Min(0,1)", Int32, Shape{2, 3}, []int{0, 1}, ScalarShape(), int32(0)}, {"A.Min(1,0)", Int32, Shape{2, 3}, []int{1, 0}, ScalarShape(), int32(0)}, {"3T.Min(1,2)", Int32, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int32{0, 12}}, + {"4T.Min()", Int32, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int32(0)}, + {"4T.Min(1,3)", Int32, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int32{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", Int32, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int32{0, 4}}, {"common case: T.Min() for int64", Int64, Shape{2, 3}, []int{}, ScalarShape(), int64(0)}, {"A.Min(0)", Int64, Shape{2, 3}, []int{0}, Shape{3}, []int64{0, 1, 2}}, {"A.Min(1)", Int64, Shape{2, 3}, []int{1}, Shape{2}, []int64{0, 3}}, {"A.Min(0,1)", Int64, Shape{2, 3}, []int{0, 1}, ScalarShape(), int64(0)}, {"A.Min(1,0)", Int64, Shape{2, 3}, []int{1, 0}, ScalarShape(), int64(0)}, {"3T.Min(1,2)", Int64, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int64{0, 12}}, + {"4T.Min()", Int64, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int64(0)}, + {"4T.Min(1,3)", Int64, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int64{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", Int64, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int64{0, 4}}, {"common case: T.Min() for uint", Uint, Shape{2, 3}, []int{}, ScalarShape(), uint(0)}, {"A.Min(0)", Uint, Shape{2, 3}, []int{0}, Shape{3}, []uint{0, 1, 2}}, {"A.Min(1)", Uint, Shape{2, 3}, []int{1}, Shape{2}, []uint{0, 3}}, {"A.Min(0,1)", Uint, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint(0)}, {"A.Min(1,0)", Uint, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint(0)}, {"3T.Min(1,2)", Uint, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint{0, 12}}, + {"4T.Min()", Uint, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint(0)}, + {"4T.Min(1,3)", Uint, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", Uint, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint{0, 4}}, {"common case: T.Min() for uint8", Uint8, Shape{2, 3}, []int{}, ScalarShape(), uint8(0)}, {"A.Min(0)", Uint8, Shape{2, 3}, []int{0}, Shape{3}, 
[]uint8{0, 1, 2}}, {"A.Min(1)", Uint8, Shape{2, 3}, []int{1}, Shape{2}, []uint8{0, 3}}, {"A.Min(0,1)", Uint8, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint8(0)}, {"A.Min(1,0)", Uint8, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint8(0)}, {"3T.Min(1,2)", Uint8, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint8{0, 12}}, + {"4T.Min()", Uint8, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint8(0)}, + {"4T.Min(1,3)", Uint8, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint8{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", Uint8, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint8{0, 4}}, {"common case: T.Min() for uint16", Uint16, Shape{2, 3}, []int{}, ScalarShape(), uint16(0)}, {"A.Min(0)", Uint16, Shape{2, 3}, []int{0}, Shape{3}, []uint16{0, 1, 2}}, {"A.Min(1)", Uint16, Shape{2, 3}, []int{1}, Shape{2}, []uint16{0, 3}}, {"A.Min(0,1)", Uint16, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint16(0)}, {"A.Min(1,0)", Uint16, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint16(0)}, {"3T.Min(1,2)", Uint16, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint16{0, 12}}, + {"4T.Min()", Uint16, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint16(0)}, + {"4T.Min(1,3)", Uint16, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint16{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", Uint16, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint16{0, 4}}, {"common case: T.Min() for uint32", Uint32, Shape{2, 3}, []int{}, ScalarShape(), uint32(0)}, {"A.Min(0)", Uint32, Shape{2, 3}, []int{0}, Shape{3}, []uint32{0, 1, 2}}, {"A.Min(1)", Uint32, Shape{2, 3}, []int{1}, Shape{2}, []uint32{0, 3}}, {"A.Min(0,1)", Uint32, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint32(0)}, {"A.Min(1,0)", Uint32, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint32(0)}, {"3T.Min(1,2)", Uint32, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint32{0, 12}}, + {"4T.Min()", Uint32, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint32(0)}, + {"4T.Min(1,3)", Uint32, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint32{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", Uint32, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint32{0, 4}}, {"common case: T.Min() for uint64", Uint64, Shape{2, 3}, []int{}, ScalarShape(), uint64(0)}, {"A.Min(0)", Uint64, Shape{2, 3}, []int{0}, Shape{3}, []uint64{0, 1, 2}}, {"A.Min(1)", Uint64, Shape{2, 3}, []int{1}, Shape{2}, []uint64{0, 3}}, {"A.Min(0,1)", Uint64, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint64(0)}, {"A.Min(1,0)", Uint64, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint64(0)}, {"3T.Min(1,2)", Uint64, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint64{0, 12}}, + {"4T.Min()", Uint64, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint64(0)}, + {"4T.Min(1,3)", Uint64, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint64{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", Uint64, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint64{0, 4}}, {"common case: T.Min() for float32", Float32, Shape{2, 3}, []int{}, ScalarShape(), float32(0)}, {"A.Min(0)", Float32, Shape{2, 3}, []int{0}, Shape{3}, []float32{0, 1, 2}}, {"A.Min(1)", Float32, Shape{2, 3}, []int{1}, Shape{2}, []float32{0, 3}}, {"A.Min(0,1)", Float32, Shape{2, 3}, []int{0, 1}, ScalarShape(), float32(0)}, {"A.Min(1,0)", Float32, Shape{2, 3}, []int{1, 0}, ScalarShape(), float32(0)}, {"3T.Min(1,2)", Float32, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []float32{0, 12}}, + {"4T.Min()", Float32, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), float32(0)}, + {"4T.Min(1,3)", Float32, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []float32{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", Float32, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []float32{0, 4}}, {"common case: 
T.Min() for float64", Float64, Shape{2, 3}, []int{}, ScalarShape(), float64(0)}, {"A.Min(0)", Float64, Shape{2, 3}, []int{0}, Shape{3}, []float64{0, 1, 2}}, {"A.Min(1)", Float64, Shape{2, 3}, []int{1}, Shape{2}, []float64{0, 3}}, {"A.Min(0,1)", Float64, Shape{2, 3}, []int{0, 1}, ScalarShape(), float64(0)}, {"A.Min(1,0)", Float64, Shape{2, 3}, []int{1, 0}, ScalarShape(), float64(0)}, {"3T.Min(1,2)", Float64, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []float64{0, 12}}, + {"4T.Min()", Float64, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), float64(0)}, + {"4T.Min(1,3)", Float64, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []float64{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", Float64, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []float64{0, 4}}, } func TestDense_Min(t *testing.T) { diff --git a/dense_selbyidx_test.go b/dense_selbyidx_test.go new file mode 100644 index 0000000..98d309a --- /dev/null +++ b/dense_selbyidx_test.go @@ -0,0 +1,133 @@ +package tensor + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type selByIndicesTest struct { + Name string + Data interface{} + Shape Shape + Indices []int + Axis int + WillErr bool + + Correct interface{} + CorrectShape Shape +} + +var selByIndicesTests = []selByIndicesTest{ + // {Name: "Basic", Data: Range(Float64, 0, 4), Shape: Shape{2, 2}, Indices: []int{0, 1}, Axis: 0, WillErr: false, + // Correct: []float64{0, 1, 2, 3}, CorrectShape: Shape{2, 2}, + // }, + // {Name: "3-tensor, axis 0", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 0, WillErr: false, + // Correct: []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}, CorrectShape: Shape{2, 2, 4}}, + + // {Name: "3-tensor, axis 1", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 1, WillErr: false, + // Correct: []float64{4, 5, 6, 7, 4, 5, 6, 7, 12, 13, 14, 15, 12, 13, 14, 15, 20, 21, 22, 23, 20, 21, 22, 23}, CorrectShape: Shape{3, 2, 4}}, + + // {Name: "3-tensor, axis 2", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 2, WillErr: false, + // Correct: []float64{1, 1, 5, 5, 9, 9, 13, 13, 17, 17, 21, 21}, CorrectShape: Shape{3, 2, 2}}, + + // {Name: "Vector, axis 0", Data: Range(Int, 0, 5), Shape: Shape{5}, Indices: []int{1, 1}, Axis: 0, WillErr: false, + // Correct: []int{1, 1}, CorrectShape: Shape{2}}, + + {Name: "Vector, axis 1", Data: Range(Int, 0, 5), Shape: Shape{5}, Indices: []int{1, 1}, Axis: 1, WillErr: true, + Correct: []int{1, 1}, CorrectShape: Shape{2}}, + // {Name: "(4,2) Matrix, with (10) indices", Data: Range(Float32, 0, 8), Shape: Shape{4, 2}, Indices: []int{1, 1, 1, 1, 0, 2, 2, 2, 2, 0}, Axis: 0, WillErr: false, + // Correct: []float32{2, 3, 2, 3, 2, 3, 2, 3, 0, 1, 4, 5, 4, 5, 4, 5, 4, 5, 0, 1}, CorrectShape: Shape{10, 2}}, + // {Name: "(2,1) Matrx (colvec) with (10) indices", Data: Range(Float64, 0, 2), Shape: Shape{2, 1}, Indices: []int{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, Axis: 0, WillErr: false, + // Correct: []float64{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, CorrectShape: Shape{10}, + // }, +} + +func TestDense_SelectByIndices(t *testing.T) { + assert := assert.New(t) + for i, tc := range selByIndicesTests { + T := New(WithShape(tc.Shape...), WithBacking(tc.Data)) + indices := New(WithBacking(tc.Indices)) + ret, err := ByIndices(T, indices, tc.Axis) + if checkErr(t, tc.WillErr, err, tc.Name, i) { + continue + } + assert.Equal(tc.Correct, ret.Data()) + assert.True(tc.CorrectShape.Eq(ret.Shape())) + } +} + +var selByIndicesBTests = []struct { + selByIndicesTest + + 
CorrectGrad interface{} + CorrectGradShape Shape +}{ + // Basic + { + CorrectGrad: []float64{1, 1, 1, 1}, + }, + // 3-tensor, axis 0 + { + CorrectGrad: []float64{0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + // 3-tensor, axis 1 + { + CorrectGrad: []float64{0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0, 2, 2, 2, 2}, + }, + // 3-tensor, axis 2 + { + CorrectGrad: []float64{0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0}, + }, + // vector, axis 0 + { + CorrectGrad: []int{0, 2, 0, 0, 0}, + }, + // vector, axis 1 + { + CorrectGrad: []float32{4, 6, 8, 12, 8, 12, 0, 0}, + }, + // (4,2) Matrix with (10) indices + { + CorrectGrad: []float32{2, 2, 4, 4, 4, 4, 0, 0}, + }, + // (2, 1) Matrix (colvec) with (10) indices + { + CorrectGrad: []float64{0, 10}, + }, +} + +func init() { + // for i := range selByIndicesBTests { + // selByIndicesBTests[i].selByIndicesTest = selByIndicesTests[i] + // selByIndicesBTests[i].CorrectGradShape = selByIndicesTests[i].Shape + // } +} + +func TestDense_SelectByIndicesB(t *testing.T) { + + assert := assert.New(t) + for i, tc := range selByIndicesBTests { + T := New(WithShape(tc.Shape...), WithBacking(tc.Data)) + indices := New(WithBacking(tc.Indices)) + ret, err := ByIndices(T, indices, tc.Axis) + if checkErr(t, tc.WillErr, err, tc.Name, i) { + continue + } + outGrad := ret.Clone().(*Dense) + switch outGrad.Dtype() { + case Float64: + outGrad.Memset(1.0) + case Float32: + outGrad.Memset(float32(1.0)) + } + + grad, err := ByIndicesB(T, outGrad, indices, tc.Axis) + if checkErr(t, tc.WillErr, err, tc.Name, i) { + continue + } + assert.Equal(tc.CorrectGrad, grad.Data(), "%v - x:\n%v\nindices:\n%#v\ny:\n%#v\ngrad:\n%v", tc.Name, T, indices, ret, grad) + assert.True(tc.CorrectGradShape.Eq(grad.Shape()), "%v - Grad shape should be %v. 
Got %v instead.\n\nx:\n%v\nindices:\n%#v\ny:\n%#v\ngrad:\n%v", tc.Name, tc.CorrectGradShape, grad.Shape(), T, indices, ret, grad) + } + +} diff --git a/dense_softmax_test.go b/dense_softmax_test.go new file mode 100644 index 0000000..eaa68df --- /dev/null +++ b/dense_softmax_test.go @@ -0,0 +1,287 @@ +package tensor + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSoftMax(t *testing.T) { + testCases := []struct { + fn func(x Tensor, axis int, opts ...FuncOpt) (Tensor, error) + x Tensor + axis int + expectedOutput interface{} + }{ + { + fn: LogSoftMax, + x: New( + Of(Float64), + WithShape(3, 4), + WithBacking([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}), + ), + axis: -1, + expectedOutput: []float64{-1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551629, -1.2425355294551628}, + }, + { + fn: LogSoftMax, + x: New( + Of(Float32), + WithShape(3, 4), + WithBacking([]float32{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}), + ), + axis: -1, + expectedOutput: []float32{-1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551629, -1.2425355294551628}, + }, + { + fn: LogSoftMax, + x: New( + Of(Float32), + WithShape(3, 2, 2), + WithBacking([]float32{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}), + ), + axis: -1, + expectedOutput: []float32{-0.7443967, -0.64439666, -0.7443967, -0.64439666, -0.7443967, -0.64439666, -0.7443966, -0.64439666, -0.7443966, -0.64439666, -0.7443967, -0.64439666}, + }, + { + fn: LogSoftMax, + x: New( + Of(Float64), + WithShape(3, 2, 2), + WithBacking([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}), + ), + axis: 1, + expectedOutput: []float64{-0.7981388693815918, -0.7981388693815918, -0.5981388693815918, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815918}, + }, + { + fn: SoftMax, + x: New( + Of(Float64), + WithShape(3, 2, 2), + WithBacking([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}), + ), + axis: 1, + expectedOutput: []float64{0.4501660026875221, 0.45016600268752205, 0.549833997312478, 0.5498339973124778, 0.45016600268752205, 0.45016600268752205, 0.5498339973124778, 0.5498339973124778, 0.45016600268752205, 0.4501660026875221, 0.5498339973124778, 0.549833997312478}, + }, + { + fn: SoftMax, + x: New( + Of(Float64), + WithShape(3, 2, 2), + WithBacking([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}), + ), + axis: -1, + expectedOutput: []float64{0.47502081252106, 0.52497918747894, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.52497918747894}, + }, + { + fn: SoftMax, + x: New( + Of(Float32), + WithShape(3, 4), + WithBacking([]float32{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}), + ), + axis: -1, + expectedOutput: []float32{0.21383822, 0.23632777, 0.2611826, 0.2886514, 0.21383823, 0.23632778, 0.2611826, 0.2886514, 0.21383822, 0.23632777, 0.26118258, 0.2886514}, + }, + { + fn: SoftMax, + x: New( + Of(Float64), + 
WithShape(3, 4), + WithBacking([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}), + ), + axis: -1, + expectedOutput: []float64{0.21383822, 0.23632777, 0.2611826, 0.2886514, 0.21383823, 0.23632778, 0.2611826, 0.2886514, 0.21383822, 0.23632777, 0.26118258, 0.2886514}, + }, + } + for i, tC := range testCases { + t.Run(fmt.Sprintf("Example #%d - %v %v", i+1, tC.x.Shape(), tC.x.Dtype()), func(t *testing.T) { + c := assert.New(t) + + output, err := tC.fn(tC.x, tC.axis) + t.Logf("output: %#v", output.Data()) + + c.NoError(err) + c.NotNil(output) + + c.Equal(tC.x.Shape(), output.Shape()) + c.InDeltaSlice(tC.expectedOutput, output.Data(), 1e-6) + }) + } +} + +func TestSoftMaxB(t *testing.T) { + testCases := []struct { + fn func(output, grad Tensor, axis int, opts ...FuncOpt) (Tensor, error) + output Tensor + grad Tensor + axis int + expectedOutput interface{} + }{ + { + fn: SoftMaxB, + output: New( + Of(Float64), + WithShape(3, 4), + WithBacking([]float64{0.21383822, 0.23632777, 0.2611826, 0.2886514, 0.21383823, 0.23632778, 0.2611826, 0.2886514, 0.21383822, 0.23632777, 0.26118258, 0.2886514}), + ), + grad: New( + Of(Float64), + WithShape(3, 4), + WithBacking([]float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), + ), + axis: -1, + expectedOutput: []float64{-0.003474116568224552, -0.0014762147035963322, 0.0009803563066858392, 0.00396997522759976, -0.003474116880376028, -0.001476214931490494, 0.0009803561238580223, 0.003969975025543781, -0.0034741159267098936, -0.0014762139946130218, 0.0009803570151630109, 0.003969976093553957}, + }, + { + fn: LogSoftMaxB, + output: New( + Of(Float64), + WithShape(3, 4), + WithBacking([]float64{-1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551629, -1.2425355294551628}), + ), + grad: New( + Of(Float64), + WithShape(3, 4), + WithBacking([]float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), + ), + axis: -1, + expectedOutput: []float64{-0.011383822036598441, -0.003632778232153768, 0.0038817407844924366, 0.01113485948425977, -0.005597937295155945, -0.001445223403599799, 0.0020925260396803457, 0.004950634659075405, 0.00018794744628654992, 0.0007423314249541871, 0.00030331129486827163, -0.0012335901661089598}, + }, + { + fn: SoftMaxB, + output: New( + Of(Float64), + WithShape(3, 2, 2), + WithBacking([]float64{0.47502081252106, 0.52497918747894, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.52497918747894}), + ), + grad: New( + Of(Float64), + WithShape(3, 2, 2), + WithBacking([]float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), + ), + axis: -1, + expectedOutput: []float64{-0.002493760401928919, 0.0024937604019289205, -0.0024937604019289183, 0.002493760401928922, -0.002493760401928915, 0.002493760401928922, -0.002493760401928912, 0.0024937604019289253, -0.0024937604019289183, 0.0024937604019289253, -0.0024937604019289183, 0.0024937604019289183}, + }, + { + fn: SoftMaxB, + output: New( + Of(Float64), + WithShape(3, 2, 2), + WithBacking([]float64{0.4501660026875221, 0.45016600268752205, 0.549833997312478, 0.5498339973124778, 0.45016600268752205, 0.45016600268752205, 0.5498339973124778, 0.5498339973124778, 0.45016600268752205, 0.4501660026875221, 0.5498339973124778, 
0.549833997312478}), + ), + grad: New( + Of(Float64), + WithShape(3, 2, 2), + WithBacking([]float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), + ), + axis: 1, + expectedOutput: []float64{-0.004950331454237199, -0.004950331454237198, 0.004950331454237199, 0.0049503314542372, -0.004950331454237196, -0.004950331454237193, 0.004950331454237203, 0.0049503314542372065, -0.004950331454237193, -0.0049503314542372, 0.0049503314542372065, 0.004950331454237193}, + }, + { + fn: LogSoftMaxB, + output: New( + Of(Float64), + WithShape(3, 2, 2), + WithBacking([]float64{-0.7981388693815918, -0.7981388693815918, -0.5981388693815918, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815918}), + ), + grad: New( + Of(Float64), + WithShape(3, 2, 2), + WithBacking([]float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), + ), + axis: 1, + expectedOutput: []float64{-0.008006640107500884, -0.007009960161251325, 0.00800664010750088, 0.007009960161251332, -0.004019920322502654, -0.003023240376253103, 0.004019920322502661, 0.0030232403762530968, -3.32005375044292e-05, 0.0009634794087451421, 3.320053750442642e-05, -0.0009634794087451543}, + }, + { + fn: LogSoftMaxB, + output: New( + Of(Float32), + WithShape(3, 2, 2), + WithBacking([]float32{-0.7981388693815918, -0.7981388693815918, -0.5981388693815918, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815918}), + ), + grad: New( + Of(Float32), + WithShape(3, 2, 2), + WithBacking([]float32{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), + ), + axis: 1, + expectedOutput: []float64{-0.008006640107500884, -0.007009960161251325, 0.00800664010750088, 0.007009960161251332, -0.004019920322502654, -0.003023240376253103, 0.004019920322502661, 0.0030232403762530968, -3.32005375044292e-05, 0.0009634794087451421, 3.320053750442642e-05, -0.0009634794087451543}, + }, + { + fn: SoftMaxB, + output: New( + Of(Float32), + WithShape(3, 2, 2), + WithBacking([]float32{0.4501660026875221, 0.45016600268752205, 0.549833997312478, 0.5498339973124778, 0.45016600268752205, 0.45016600268752205, 0.5498339973124778, 0.5498339973124778, 0.45016600268752205, 0.4501660026875221, 0.5498339973124778, 0.549833997312478}), + ), + grad: New( + Of(Float32), + WithShape(3, 2, 2), + WithBacking([]float32{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), + ), + axis: 1, + expectedOutput: []float32{-0.004950331454237199, -0.004950331454237198, 0.004950331454237199, 0.0049503314542372, -0.004950331454237196, -0.004950331454237193, 0.004950331454237203, 0.0049503314542372065, -0.004950331454237193, -0.0049503314542372, 0.0049503314542372065, 0.004950331454237193}, + }, + { + fn: SoftMaxB, + output: New( + Of(Float32), + WithShape(3, 4), + WithBacking([]float32{0.21383822, 0.23632777, 0.2611826, 0.2886514, 0.21383823, 0.23632778, 0.2611826, 0.2886514, 0.21383822, 0.23632777, 0.26118258, 0.2886514}), + ), + grad: New( + Of(Float64), + WithShape(3, 4), + WithBacking([]float32{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), + ), + axis: -1, + expectedOutput: []float32{-0.003474116568224552, -0.0014762147035963322, 0.0009803563066858392, 0.00396997522759976, -0.003474116880376028, -0.001476214931490494, 0.0009803561238580223, 
0.003969975025543781, -0.0034741159267098936, -0.0014762139946130218, 0.0009803570151630109, 0.003969976093553957}, + }, + { + fn: LogSoftMaxB, + output: New( + Of(Float64), + WithShape(3, 4), + WithBacking([]float32{-1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551629, -1.2425355294551628}), + ), + grad: New( + Of(Float64), + WithShape(3, 4), + WithBacking([]float32{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), + ), + axis: -1, + expectedOutput: []float32{-0.011383822036598441, -0.003632778232153768, 0.0038817407844924366, 0.01113485948425977, -0.005597937295155945, -0.001445223403599799, 0.0020925260396803457, 0.004950634659075405, 0.00018794744628654992, 0.0007423314249541871, 0.00030331129486827163, -0.0012335901661089598}, + }, + { + fn: SoftMaxB, + output: New( + Of(Float64), + WithShape(3, 2, 2), + WithBacking([]float32{0.47502081252106, 0.52497918747894, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.52497918747894}), + ), + grad: New( + Of(Float64), + WithShape(3, 2, 2), + WithBacking([]float32{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), + ), + axis: -1, + expectedOutput: []float32{-0.002493760401928919, 0.0024937604019289205, -0.0024937604019289183, 0.002493760401928922, -0.002493760401928915, 0.002493760401928922, -0.002493760401928912, 0.0024937604019289253, -0.0024937604019289183, 0.0024937604019289253, -0.0024937604019289183, 0.0024937604019289183}, + }, + } + for i, tC := range testCases { + t.Run(fmt.Sprintf("Example #%d - %v %v", i+1, tC.output.Shape(), tC.output.Dtype()), func(t *testing.T) { + c := assert.New(t) + + dx, err := tC.fn(tC.output, tC.grad, tC.axis) + t.Logf("output: %#v", tC.output.Data()) + + c.NoError(err) + c.NotNil(dx) + + c.Equal(tC.output.Shape(), dx.Shape()) + c.InDeltaSlice(tC.expectedOutput, dx.Data(), 1e-6) + }) + } +} diff --git a/dense_svd_test.go b/dense_svd_test.go index 36e4e16..282868b 100644 --- a/dense_svd_test.go +++ b/dense_svd_test.go @@ -1,6 +1,7 @@ package tensor import ( + "fmt" "testing" "github.com/pkg/errors" @@ -103,6 +104,27 @@ func testSVD(T, T2, s, u, v *Dense, t string, i int) (err error) { return nil } +func ExampleDense_SVD() { + T := New( + WithShape(4, 5), + WithBacking([]float64{1, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0}), + ) + _, u, _, _ := T.SVD(true, true) + uT := u.Clone().(*Dense) + uT.T() + eye, err := u.MatMul(uT) + fmt.Println(eye) + fmt.Println(err) + + // Output: + // ⎡1 0 0 0⎤ + // ⎢0 1 0 0⎥ + // ⎢0 0 1 0⎥ + // ⎣0 0 0 1⎦ + // + // +} + func TestDense_SVD(t *testing.T) { var T, T2, s, u, v *Dense var err error @@ -134,7 +156,6 @@ func TestDense_SVD(t *testing.T) { t.Errorf("Expected v = %v. 
Got %v instead", stts.correctVData, v.Data()) } } - // standard tests for i, stfs := range svdtestsFull { T = New(WithShape(stfs...), WithBacking(Random(Float64, stfs.TotalSize()))) @@ -143,14 +164,14 @@ func TestDense_SVD(t *testing.T) { // full if s, u, v, err = T.SVD(true, true); err != nil { t.Error(err) + fmt.Println(err) continue } - if err = testSVD(T, T2, s, u, v, "full", i); err != nil { t.Error(err) + fmt.Println(err) continue } - // thin if s, u, v, err = T.SVD(true, false); err != nil { t.Error(err) @@ -183,8 +204,8 @@ func TestDense_SVD(t *testing.T) { if !allClose(s.Data(), svd.Values(nil), closeenoughf64) { t.Errorf("Singular value mismatch between Full and None decomposition. Expected %v. Got %v instead", svd.Values(nil), s.Data()) } - } + } // this is illogical T = New(Of(Float64), WithShape(2, 2)) if _, _, _, err = T.SVD(false, true); err == nil { diff --git a/dense_test.go b/dense_test.go index bcb9ba3..d7e81e6 100644 --- a/dense_test.go +++ b/dense_test.go @@ -5,7 +5,6 @@ import ( "testing" "testing/quick" "time" - "unsafe" "github.com/stretchr/testify/assert" ) @@ -90,27 +89,52 @@ func TestFromScalar(t *testing.T) { assert.Equal(t, []float64{3.14}, data) } -func TestFromMemory(t *testing.T) { - // dummy memory - this could be an externally malloc'd memory, or a mmap'ed file. - // but here we're just gonna let Go manage memory. - s := make([]float64, 100) - ptr := uintptr(unsafe.Pointer(&s[0])) - size := uintptr(100 * 8) - - T := New(Of(Float32), WithShape(50, 4), FromMemory(ptr, size)) - if len(T.Float32s()) != 200 { - t.Error("expected 200 Float32s") - } - assert.Equal(t, make([]float32, 200), T.Data()) - assert.True(t, T.IsManuallyManaged(), "Unamanged %v |%v | q: %v", ManuallyManaged, T.flag, (T.flag>>ManuallyManaged)&MemoryFlag(1)) - - fail := func() { New(FromMemory(ptr, size), Of(Float32)) } - assert.Panics(t, fail, "Expected bad New() call to panic") -} - func Test_recycledDense(t *testing.T) { T := recycledDense(Float64, ScalarShape()) assert.Equal(t, float64(0), T.Data()) assert.Equal(t, StdEng{}, T.e) assert.Equal(t, StdEng{}, T.oe) } + +func TestDense_unsqueeze(t *testing.T) { + assert := assert.New(t) + T := New(WithShape(3, 3, 2), WithBacking([]float64{ + 1, 2, 3, 4, 5, 6, + 60, 50, 40, 30, 20, 10, + 100, 200, 300, 400, 500, 600, + })) + + if err := T.unsqueeze(0); err != nil { + t.Fatal(err) + } + + assert.True(T.Shape().Eq(Shape{1, 3, 3, 2})) + assert.Equal([]int{6, 6, 2, 1}, T.Strides()) // if you do shapes.CalcStrides() it'd be {18,6,2,1} + + // reset + T.Reshape(3, 3, 2) + + if err := T.unsqueeze(1); err != nil { + t.Fatal(err) + } + assert.True(T.Shape().Eq(Shape{3, 1, 3, 2})) + assert.Equal([]int{6, 2, 2, 1}, T.Strides()) + + // reset + T.Reshape(3, 3, 2) + if err := T.unsqueeze(2); err != nil { + t.Fatal(err) + } + t.Logf("%v", T) + assert.True(T.Shape().Eq(Shape{3, 3, 1, 2})) + assert.Equal([]int{6, 2, 1, 1}, T.Strides()) + + // reset + T.Reshape(3, 3, 2) + if err := T.unsqueeze(3); err != nil { + t.Fatal(err) + } + t.Logf("%v", T) + assert.True(T.Shape().Eq(Shape{3, 3, 2, 1})) + assert.Equal([]int{6, 2, 1, 1}, T.Strides()) +} diff --git a/dense_views.go b/dense_views.go index d56ee8c..ab3c537 100644 --- a/dense_views.go +++ b/dense_views.go @@ -3,13 +3,42 @@ package tensor // a View is a *Tensor with customized strides. 
The reason for not splitting them up into different types is complicated // this file contains all the methods that deals with Views +var _ View = DenseView{} + +// Dense +type DenseView struct { + *Dense +} + +// RequiresIterator returns true if an iterator is required to read the data in the correct fashion. +func (t DenseView) RequiresIterator() bool { + if t.len() == 1 { + return false + } + // non continuous slice, transpose, or masked. If it's a slice and contiguous, then iterator is not required + if !t.o.IsContiguous() || !t.old.IsZero() || t.IsMasked() { + return true + } + return false +} + +// IsView indicates if the Tensor is a view of another (typically from slicing) +func (t DenseView) IsView() bool { + return t.viewOf != 0 +} + +// IsMaterializeable indicates if the Tensor is materializable - if it has either gone through some transforms or slicing +func (t DenseView) IsMaterializable() bool { + return t.viewOf != 0 || !t.old.IsZero() +} + // Materialize takes a view, copies its data and puts it in a new *Tensor. -func (t *Dense) Materialize() Tensor { +func (t DenseView) Materialize() Tensor { if !t.IsMaterializable() { - return t + return t.Dense } - retVal := recycledDense(t.t, t.shape.Clone()) + retVal := recycledDense(t.t, t.shape.Clone(), WithEngine(t.e)) copyDenseIter(retVal, t, nil, nil) retVal.e = t.e retVal.oe = t.oe diff --git a/divmod.s b/divmod_amd64.s similarity index 100% rename from divmod.s rename to divmod_amd64.s diff --git a/engine.go b/engine.go index af56f6b..2a115e7 100644 --- a/engine.go +++ b/engine.go @@ -1,7 +1,9 @@ package tensor import ( - "unsafe" + "context" + + "gorgonia.org/dtype" ) // Memory is a representation of memory of the value. @@ -13,7 +15,6 @@ import ( type Memory interface { Uintptr() uintptr MemSize() uintptr - Pointer() unsafe.Pointer } // Engine is a representation of an execution engine. @@ -29,7 +30,16 @@ type Engine interface { WorksWith(order DataOrder) bool // WorksWith returns true if the data order can be directly worked with } -type standardEngine interface { +// StandardEngine is any engine that wraps a StdEng{}. +type StandardEngine interface { + StandardEngine2 + + // anything that wraps StdEng will contain the following interfaces: + arrayMaker +} + +// StandardEngine2 is any engine that implements the basic operations of a standard engine. +type StandardEngine2 interface { Engine Adder @@ -49,6 +59,8 @@ type standardEngine interface { Gter Gteer ElEqer + MinBetweener + MaxBetweener // Anything that returns interface{} cannot be added here because they will likely have additional // optimized versions of the functions for types. @@ -56,34 +68,50 @@ type standardEngine interface { } type arrayMaker interface { - makeArray(arr *array, t Dtype, size int) + makeArray(arr *array, t dtype.Dtype, size int) +} + +// contexter is any engine (or type) that returns the current context. +type contexter interface { + Context() context.Context +} + +// NonStdEngine are any engines that do not allocate using the default built in allocator +type NonStdEngine interface { + NonStdAlloc() // noop } /* Data Agnostic Execution Engine Methods */ // Transposer is any engine that can perform an unsafe transpose of a tensor. 
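// Illustrative sketch (not part of the diff): how the DenseView methods introduced in
// dense_views.go above are expected to be used. The construction of a DenseView by wrapping
// an existing *Dense is an assumption for illustration; only RequiresIterator, IsView,
// IsMaterializable and Materialize are taken from the code above.
func exampleDenseViewMaterialize() {
	T := New(WithShape(2, 3), WithBacking([]float64{1, 2, 3, 4, 5, 6}))
	T.T() // transpose in place: the data is now "transposed but not moved"

	dv := DenseView{Dense: T}
	if dv.IsMaterializable() { // true: the tensor carries an un-applied transpose
		M := dv.Materialize() // copies the viewed data out into a fresh, contiguous Tensor
		_ = M
	}
	_ = dv.RequiresIterator() // also true here, because old (the pre-transpose AP) is not zero
}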
type Transposer interface { - Transpose(t Tensor, expStrides []int) error + Transpose(ctx context.Context, t Tensor, expStrides []int) error } // Concater is any enegine that can concatenate multiple Tensors together type Concater interface { - Concat(t Tensor, axis int, others ...Tensor) (Tensor, error) + Concat(ctx context.Context, t Tensor, axis int, others ...Tensor) (Tensor, error) } // Stacker is any engine that can stack multiple Tenosrs along an axis type Stacker interface { - Stack(t Tensor, axis int, others ...Tensor) (Tensor, error) + Stack(ctx context.Context, t Tensor, axis int, others ...Tensor) (Tensor, error) } // DenseStacker is any engine that can stack DenseTensors along an axis. This is a specialization of Stacker. type DenseStacker interface { - StackDense(t DenseTensor, axis int, others ...DenseTensor) (retVal DenseTensor, err error) + StackDense(ctx context.Context, t DenseTensor, axis int, others ...DenseTensor) (retVal DenseTensor, err error) } // Repeater is any engine that can repeat values along the given axis. type Repeater interface { - Repeat(t Tensor, axis int, repeats ...int) (Tensor, error) + Repeat(ctx context.Context, t Tensor, axis int, repeats ...int) (Tensor, error) + RepeatReuse(ctx context.Context, t Tensor, reuse Tensor, axis int, repeeats ...int) (Tensor, error) +} + +// Diager is any engine that can return a tensor that only contains the diagonal values of the input +type Diager interface { + Diag(ctx context.Context, a Tensor) (Tensor, error) } /* NUMBER INTERFACES @@ -151,47 +179,61 @@ type Moder interface { ModScalar(a Tensor, b interface{}, leftTensor bool, opts ...FuncOpt) (Tensor, error) } +// MinBetweener is any engine that can perform an elementwise min=between. +type MinBetweener interface { + MinBetween(a, b Tensor, opts ...FuncOpt) (Tensor, error) + + MinBetweenScalar(a Tensor, b interface{}, leftTensor bool, opts ...FuncOpt) (Tensor, error) +} + +// MaxBetweener is any engine that can perform an elementwise ma>> np.diag(np.diag(np.array([1,2,3]))) +// array([1,2,3]) +func Example_asDenseDiag() { + T := New(WithShape(3), WithBacking([]int{1, 2, 3})) + T1 := New(AsDenseDiag(T)) + fmt.Printf("T1:\n%v", T1) + + T2 := New(AsDenseDiag([]float64{3.14, 6.28, 11111})) + fmt.Printf("T2:\n%v", T2) + // Output: + // T1: + //⎡1 0 0⎤ + //⎢0 2 0⎥ + //⎣0 0 3⎦ + // T2: + // ⎡ 3.14 0 0⎤ + // ⎢ 0 6.28 0⎥ + // ⎣ 0 0 11111⎦ +} diff --git a/flags.go b/flags.go index dfe551e..5cc0bae 100644 --- a/flags.go +++ b/flags.go @@ -13,8 +13,13 @@ const ( // A data can either be Contiguous (0) or NonContiguous (2). // The way DataOrder was designed causes the default to be Contiguous. NonContiguous + + // Transposed indicates that the data has been transposed + Transposed ) +var dataOrderNames = []rune("NonContiguous, RowMajorᵀNonContiguous, ColMajorᵀ") + // MakeDataOrder makes a data order. 
Typical examples: // MakeDataOrder(DataOrder(0)) // Row Major, contiguous // MakeDataOrder(NonContiguous // Row Major, non-contiguous @@ -30,13 +35,48 @@ func MakeDataOrder(fs ...DataOrder) (retVal DataOrder) { return } -func (f DataOrder) isColMajor() bool { return (f & ColMajor) != 0 } -func (f DataOrder) isRowMajor() bool { return !f.isColMajor() } -func (f DataOrder) isContiguous() bool { return !f.isNotContiguous() } -func (f DataOrder) isNotContiguous() bool { return (f & NonContiguous) != 0 } +// IsColMajor returns true if the data order describes a col-major data +func (f DataOrder) IsColMajor() bool { return (f & ColMajor) != 0 } + +// IsRowMajor returns true if the data order describes a row-major data +func (f DataOrder) IsRowMajor() bool { return !f.IsColMajor() } + +// IsContiguous returns true if the data order describes a contiguous data. +func (f DataOrder) IsContiguous() bool { return !f.IsNotContiguous() } + +// IsNotContiguous returns true if the data order describes a noncontiguous data. +func (f DataOrder) IsNotContiguous() bool { return (f & NonContiguous) != 0 } + +// IsTransposed returns true if the data order describes whether the data has been tranposed (but not moved) +func (f DataOrder) IsTransposed() bool { return (f & Transposed) != 0 } + func (f DataOrder) toggleColMajor() DataOrder { return f ^ (ColMajor) } -func (f DataOrder) hasSameOrder(other DataOrder) bool { - return (f.isColMajor() && other.isColMajor()) || (f.isRowMajor() && other.isRowMajor()) + +func (f DataOrder) clearTransposed() DataOrder { return f &^ (Transposed) } + +// HasSameOrder returns true if both data orders are the same (either both are ColMajor or both are RowMajor) +func (f DataOrder) HasSameOrder(other DataOrder) bool { + return (f.IsColMajor() && other.IsColMajor()) || (f.IsRowMajor() && other.IsRowMajor()) +} + +func (f DataOrder) String() string { + var start, end int + if f.IsRowMajor() { + end = 23 + if f.IsContiguous() { + start = 3 + } + } else { + end = 47 + start = 24 + if f.IsContiguous() { + start = 27 + } + } + if f.IsTransposed() { + end++ + } + return string(dataOrderNames[start:end]) } // Triangle is a flag representing the "triangle"ness of a matrix @@ -58,6 +98,8 @@ const ( // ManuallyManaged indicates that the memory is managed by something else. Any Tensor with // manually managed memory will not be returned to the pool. ManuallyManaged + // IsOverallocated indicates that the memory for a given tensor is overallocated (i.e. the size-in-use is smaller than the size allocated) + IsOverallocated ) func MakeMemoryFlag(fs ...MemoryFlag) (retVal MemoryFlag) { @@ -73,52 +115,4 @@ func MakeMemoryFlag(fs ...MemoryFlag) (retVal MemoryFlag) { func (f MemoryFlag) nativelyAccessible() bool { return !((f & NativelyInaccessible) != 0) } func (f MemoryFlag) manuallyManaged() bool { return (f & ManuallyManaged) != 0 } - -// OpOpt are the options used to call ops -type OpOpt struct { - reuse Tensor - incr Tensor - unsafe bool - same bool - t Dtype -} - -// ParseFuncOpts parses a list of FuncOpt into a single unified method call structure. -func ParseFuncOpts(opts ...FuncOpt) *OpOpt { - retVal := borrowOpOpt() - for _, opt := range opts { - opt(retVal) - } - return retVal -} - -// Incr returns the tensor to be incremented in the call. Can be nil. -func (fo *OpOpt) Incr() Tensor { return fo.incr } - -// Reuse returns the tensor to be reused in the call. Can be nil. 
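// Illustrative sketch (not part of the diff): exercising the newly exported DataOrder
// predicates and the String method defined above. The commented values follow the
// expectations in flags_test.go further down.
func exampleDataOrderFlags() {
	def := MakeDataOrder() // row-major, contiguous by default
	_ = def.IsRowMajor()   // true
	_ = def.IsContiguous() // true
	_ = def.String()       // "Contiguous, RowMajor"

	tr := MakeDataOrder(NonContiguous, Transposed)
	_ = tr.IsTransposed()         // true
	_ = tr.String()               // "NonContiguous, RowMajorᵀ"
	_ = tr.HasSameOrder(ColMajor) // false: one is row-major, the other col-major
	_ = tr.clearTransposed()      // drops the ᵀ bit, leaving plain NonContiguous
}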
-func (fo *OpOpt) Reuse() Tensor { return fo.reuse } - -// IncReuse returns whether a reuse tensor is to be used as the incr Tensor -func (fo *OpOpt) IncrReuse() (Tensor, bool) { - if fo.incr != nil { - return fo.incr, true - } - return fo.reuse, false -} - -// Safe signals if the op is to be done safely -func (fo *OpOpt) Safe() bool { return !fo.unsafe } - -// Same signals if the op is to return the same type as its inputs -func (fo *OpOpt) Same() bool { return fo.same } - -// As returns the dtype of the return value of the method call. -// For example: -// a.Lt(b, As(Bool)) -// indicates that the result of the `Lt()` should be a Tensor of Bool. -// -// Another example: -// a.Add(b, As(Int)) -// indicates that the result of `Add()` should be converted to a Tensor of Int. -// Note that this function is not yet supported in most operations. -func (fo *OpOpt) As() Dtype { return fo.t } +func (f MemoryFlag) isOverallocated() bool { return (f & IsOverallocated) != 0 } diff --git a/flags_test.go b/flags_test.go index 98a8772..26d10e8 100644 --- a/flags_test.go +++ b/flags_test.go @@ -1,63 +1,90 @@ -package tensor - -import "testing" - -func TestMemoryFlag(t *testing.T) { - var defaultFlag MemoryFlag - if defaultFlag.manuallyManaged() || !defaultFlag.nativelyAccessible() { - t.Errorf("Something went wrong with the creation of flags") - } - - a := ManuallyManaged - if !a.manuallyManaged() { - t.Errorf("Expected ManuallyManaged to be true") - } - if !a.nativelyAccessible() { - t.Errorf("Expected ManuallyManaged to be nativelyAccessible") - } - - b := NativelyInaccessible - if b.manuallyManaged() { - t.Errorf("Expected NativelyInaccessible to not be manually managed") - } - if b.nativelyAccessible() { - t.Errorf("Expected NativelyInaccessible to be false %v", b.nativelyAccessible()) - } - - c := MakeMemoryFlag(ManuallyManaged, NativelyInaccessible) - if !c.manuallyManaged() { - t.Errorf("Expected c to be manually managed") - } - if c.nativelyAccessible() { - t.Errorf("Expected c to be natively inaccessible") - } -} - -func TestDataOrder(t *testing.T) { - var defaultFlag DataOrder - if defaultFlag.isColMajor() || defaultFlag.isNotContiguous() { - t.Errorf("Expected default flag to be row major and contiguous") - } - if !(defaultFlag.isRowMajor() && defaultFlag.isContiguous()) { - t.Errorf("Expected default flag to be row major and contiguous") - } - - cm := ColMajor - if cm.isRowMajor() { - t.Errorf("colMajor cannot be rowMajor") - } - if cm.isNotContiguous() { - t.Errorf("ColMajor by default is contiguous") - } - - // check toggle - rm := cm.toggleColMajor() - if rm.isColMajor() { - t.Errorf("toggled cm should be rm") - } - - cm = rm.toggleColMajor() - if cm.isRowMajor() { - t.Errorf("toggled rm should be cm") - } -} +package tensor + +import "testing" + +func TestMemoryFlag(t *testing.T) { + var defaultFlag MemoryFlag + if defaultFlag.manuallyManaged() || !defaultFlag.nativelyAccessible() { + t.Errorf("Something went wrong with the creation of flags") + } + + a := ManuallyManaged + if !a.manuallyManaged() { + t.Errorf("Expected ManuallyManaged to be true") + } + if !a.nativelyAccessible() { + t.Errorf("Expected ManuallyManaged to be nativelyAccessible") + } + + b := NativelyInaccessible + if b.manuallyManaged() { + t.Errorf("Expected NativelyInaccessible to not be manually managed") + } + if b.nativelyAccessible() { + t.Errorf("Expected NativelyInaccessible to be false %v", b.nativelyAccessible()) + } + + c := MakeMemoryFlag(ManuallyManaged, NativelyInaccessible) + if !c.manuallyManaged() { + 
t.Errorf("Expected c to be manually managed") + } + if c.nativelyAccessible() { + t.Errorf("Expected c to be natively inaccessible") + } +} + +func TestDataOrder(t *testing.T) { + var defaultFlag DataOrder + if defaultFlag.IsColMajor() || defaultFlag.IsNotContiguous() || defaultFlag.IsTransposed() { + t.Error("Expected default flag to be row major and contiguous and not transposed") + } + if !(defaultFlag.IsRowMajor() && defaultFlag.IsContiguous()) { + t.Error("Expected default flag to be row major and contiguous") + } + if defaultFlag.String() != "Contiguous, RowMajor" { + t.Errorf("Expected string is \"Contiguous, RowMajor\". Got %q", defaultFlag.String()) + } + + ncrm := MakeDataOrder(NonContiguous) + if ncrm.IsColMajor() || ncrm.IsContiguous() { + t.Error("Expected noncontiguous row major.") + } + if ncrm.String() != "NonContiguous, RowMajor" { + t.Errorf("Expected string is \"NonContiguous, RowMajor\". Got %q", defaultFlag.String()) + } + + cm := ColMajor + if cm.IsRowMajor() { + t.Error("colMajor cannot be rowMajor") + } + if cm.IsNotContiguous() { + t.Error("ColMajor by default is contiguous") + } + if cm.String() != "Contiguous, ColMajor" { + t.Errorf(`Expected string is "Contiguous, ColMajor". Got %q`, cm.String()) + } + + // check toggle + rm := cm.toggleColMajor() + if rm.IsColMajor() { + t.Errorf("toggled cm should be rm") + } + + cm = rm.toggleColMajor() + if cm.IsRowMajor() { + t.Errorf("toggled rm should be cm") + } + + transposed := MakeDataOrder(Transposed) + if !transposed.IsTransposed() { + t.Error("Expected transposed flag to be set") + } + if transposed.String() != "Contiguous, RowMajorᵀ" { + t.Errorf("Expected string is \"Contiguous, RowMajorᵀ\". Got %q", defaultFlag.String()) + } + untransposed := transposed.clearTransposed() + if untransposed != defaultFlag { + t.Error("Expected default flag after untransposing") + } + +} diff --git a/funcopts.go b/funcopts.go new file mode 100644 index 0000000..65e3d4b --- /dev/null +++ b/funcopts.go @@ -0,0 +1,153 @@ +package tensor + +import ( + "context" + + "gorgonia.org/dtype" +) + +// FuncOpt are optionals for calling Tensor functions. +// The `*opOpt` type is unexported, but it's methods are exported. +// This is intentional as use of the `*opOpt` is very specialized. +// See funcopts.go for more information. +type FuncOpt func(*opOpt) + +// WithIncr passes in a Tensor to be incremented. +func WithIncr(incr Tensor) FuncOpt { + f := func(opt *opOpt) { + opt.incr = incr + } + return f +} + +// WithReuse passes in a Tensor to be reused. +func WithReuse(reuse Tensor) FuncOpt { + f := func(opt *opOpt) { + opt.reuse = reuse + } + return f +} + +// UseSafe ensures that the operation is a safe operation (copies data, does not clobber). This is the default option for most methods and functions +func UseSafe() FuncOpt { + f := func(opt *opOpt) { + opt.unsafe = false + } + return f +} + +// UseUnsafe ensures that the operation is an unsafe operation - data will be clobbered, and operations performed inplace +func UseUnsafe() FuncOpt { + f := func(opt *opOpt) { + opt.unsafe = true + } + return f +} + +// AsSameType makes sure that the return Tensor is the same type as input Tensors. +func AsSameType() FuncOpt { + f := func(opt *opOpt) { + opt.same = true + } + return f +} + +// As makes sure that the the return Tensor is of the type specified. 
Currently only works for FromMat64 +func As(t dtype.Dtype) FuncOpt { + f := func(opt *opOpt) { + opt.t = t + } + return f +} + +// WithContext allows a function to be called with a given context +func WithContext(ctx context.Context) FuncOpt { + f := func(opt *opOpt) { + opt.ctx = ctx + } + return f +} + +// opOpt are the options used to call ops +type opOpt struct { + reuse Tensor + incr Tensor + unsafe bool + same bool + t dtype.Dtype + ctx context.Context +} + +// ParseFuncOpts parses a list of FuncOpt into a single unified method call structure. +func ParseFuncOpts(opts ...FuncOpt) *opOpt { + retVal := borrowOpOpt() + + for _, opt := range opts { + opt(retVal) + } + if retVal.ctx == nil { + retVal.ctx = context.Background() // default context - required for no panics. + } + return retVal +} + +// Incr returns the tensor to be incremented in the call. Can be nil. +func (fo *opOpt) Incr() Tensor { return fo.incr } + +// Reuse returns the tensor to be reused in the call. Can be nil. +func (fo *opOpt) Reuse() Tensor { return fo.reuse } + +// IncrReuse returns whether a reuse tensor is to be used as the incr Tensor +func (fo *opOpt) IncrReuse() (Tensor, bool) { + if fo.incr != nil { + return fo.incr, true + } + return fo.reuse, false +} + +// Safe signals if the op is to be done safely +func (fo *opOpt) Safe() bool { return !fo.unsafe } + +// Same signals if the op is to return the same type as its inputs +func (fo *opOpt) Same() bool { return fo.same } + +// As returns the dtype of the return value of the method call. +// For example: +// a.Lt(b, As(Bool)) +// indicates that the result of the `Lt()` should be a Tensor of Bool. +// +// Another example: +// a.Add(b, As(Int)) +// indicates that the result of `Add()` should be converted to a Tensor of Int. +// Note that this function is not yet supported in most operations. +func (fo *opOpt) As() dtype.Dtype { return fo.t } + +// Context returns a context.Context that may have been passed in as a function option. +func (fo *opOpt) Context() context.Context { return fo.ctx } + +// SetReuse allows the reuse parameter to be set. +func (fo *opOpt) SetReuse(reuse Tensor) { fo.reuse = reuse } + +// SetIncr allows the incr parameter to be set. +func (fo *opOpt) SetIncr(incr Tensor) { fo.incr = incr } + +// FuncOpts is the inverse of ParseFuncOpts. +func (fo *opOpt) FuncOpts() []FuncOpt { + retVal := make([]FuncOpt, 0, 4) + if fo.reuse != nil { + retVal = append(retVal, WithReuse(fo.reuse)) + } + if fo.incr != nil { + retVal = append(retVal, WithIncr(fo.incr)) + } + if fo.unsafe { + retVal = append(retVal, UseUnsafe()) + } + if fo.same { + retVal = append(retVal, AsSameType()) + } + if fo.t != (Dtype{}) { + retVal = append(retVal, As(fo.t)) + } + return retVal +} diff --git a/generic_utils.go b/generic_utils.go index 24310b5..9a44263 100644 --- a/generic_utils.go +++ b/generic_utils.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -7,14 +5,17 @@ import ( "reflect" "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/vecf32" "gorgonia.org/vecf64" ) +// Code generated by genlib2. DO NOT EDIT. + // Range creates a ranged array with a given type. It panics if the Dtype is not supported or does not represent a naturally orderable type (strings, pointers etc) // Do note that the range algorithm is very simple, and simply does increments or decrements of 1. 
This means for floating point types // you're not able to create a range with a 0.1 increment step, and for complex number types, the imaginary part will always be 0i -func Range(dt Dtype, start, end int) interface{} { +func Range(dt dtype.Dtype, start, end int) interface{} { size := end - start incr := true if start > end { @@ -172,7 +173,7 @@ func Range(dt Dtype, start, end int) interface{} { // For complex Dtypes, the imaginary component will be 0. // // This function is only useful in cases where the randomness is not vital. -func Random(dt Dtype, size int) interface{} { +func Random(dt dtype.Dtype, size int) interface{} { r := rand.New(rand.NewSource(1337)) switch dt.Kind() { case reflect.Int: diff --git a/genlib2/agg1_body.go b/genlib2/agg1_body.go index 2ca9d96..85580d0 100644 --- a/genlib2/agg1_body.go +++ b/genlib2/agg1_body.go @@ -5,8 +5,8 @@ import "text/template" // level 1 aggregation (internal.E) templates const ( - eArithRaw = `as := isScalar(a) - bs := isScalar(b) + eArithRaw = `as := isScalar(a, t) + bs := isScalar(b, t) {{$name := .Name}} switch t { {{range .Kinds -}} @@ -25,18 +25,18 @@ const ( default: {{if and $isDiv $p}} err = {{end}} Vec{{$name}}{{short .}}(at, bt) } - return + return {{end -}} default: return errors.Errorf("Unsupported type %v for {{$name}}", t) } ` - eArithIncrRaw = `as := isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + eArithIncrRaw = `as := isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on scalar increment. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on scalar increment. a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } {{$name := .Name}} switch t { @@ -60,14 +60,14 @@ const ( default: {{$name}}Incr{{short .}}(at, bt,it) } - return + return {{end -}} default: return errors.Errorf("Unsupported type %v for {{$name}}", t) } ` - eArithIterRaw = `as := isScalar(a) - bs := isScalar(b) + eArithIterRaw = `as := isScalar(a, t) + bs := isScalar(b, t) {{$name := .Name}} switch t { {{range .Kinds -}} @@ -87,16 +87,16 @@ const ( return {{end -}} default: - return errors.Errorf("Unsupported type %v for {{$name}}", t) + return errors.Errorf("Unsupported type %v for {{$name}}Iter", t) } ` - eArithIterIncrRaw = `as :=isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + eArithIterIncrRaw = `as :=isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } {{$name := .Name}} switch t { @@ -122,11 +122,34 @@ const ( } {{end -}} default: - return errors.Errorf("Unsupported type %v for {{$name}}", t) + return errors.Errorf("Unsupported type %v for {{$name}}IterIncr", t) + } + ` + + eArithRecvRaw = `as :=isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(recv, t) + + if ((as && !bs) || (bs && !as)) && rs { + return errors.Errorf("Cannot increment on a scalar increment. 
len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) + } + + {{$name := .Name}} + switch t{ + {{range .Kinds -}} + case {{reflectKind .}}: + at := a.{{sliceOf .}} + bt := b.{{sliceOf .}} + rt := recv.{{sliceOf .}} + {{$name}}Recv{{short .}}(at, bt, rt) + return + {{end -}} + default: + return errors.Errorf("Unsupported type %v for {{$name}}Recv", t) } ` - eMapRaw = `as := isScalar(a) + eMapRaw = `as := isScalar(a, t) switch t { {{range .Kinds -}} case {{reflectKind .}}: @@ -181,11 +204,11 @@ const ( Map{{short .}}(f0, at) } {{end -}} - + {{end -}} default: return errors.Errorf("Cannot map t of %v", t) - + } ` @@ -233,8 +256,8 @@ const ( } ` - eCmpSameRaw = `as := isScalar(a) - bs := isScalar(b) + eCmpSameRaw = `as := isScalar(a, t) + bs := isScalar(b, t) {{$name := .Name}} switch t { {{range .Kinds -}} @@ -252,20 +275,20 @@ const ( default: {{$name}}Same{{short .}}(at, bt) } - return + return {{end -}} {{end -}} default: return errors.Errorf("Unsupported type %v for {{$name}}", t) }` - eCmpBoolRaw = `as := isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + eCmpBoolRaw = `as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is a scalar. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("retVal is a scalar. a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } {{$name := .Name}} @@ -285,15 +308,15 @@ const ( default: {{$name}}{{short .}}(at, bt, rt) } - return + return {{end -}} default: return errors.Errorf("Unsupported type %v for {{$name}}", t) } ` - eCmpSameIterRaw = `as := isScalar(a) - bs := isScalar(b) + eCmpSameIterRaw = `as := isScalar(a, t) + bs := isScalar(b, t) {{$name := .Name}} switch t { {{range .Kinds -}} @@ -319,13 +342,13 @@ const ( } ` - eCmpBoolIterRaw = `as :=isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + eCmpBoolIterRaw = `as :=isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } {{$name := .Name}} @@ -350,6 +373,59 @@ const ( return errors.Errorf("Unsupported type %v for {{$name}}", t) } ` + eMinMaxSameRaw = `as := isScalar(a, t) + bs := isScalar(b, t) + {{$name := .Name}} + switch t { + {{range .Kinds -}} + {{if isOrd . -}} + case {{reflectKind .}}: + at := a.{{sliceOf .}} + bt := b.{{sliceOf .}} + switch { + case as && bs: + Vec{{$name}}{{short .}}(at, bt) + case as && !bs: + {{$name}}SV{{short .}}(at[0], bt) + case !as && bs: + {{$name}}VS{{short .}}(at, bt[0]) + default: + Vec{{$name}}{{short .}}(at, bt) + } + return + {{end -}} + {{end -}} + default: + return errors.Errorf("Unsupported type %v for {{$name}}", t) + } + ` + + eMinMaxSameIterRaw = `as := isScalar(a, t) + bs := isScalar(b, t) + {{$name := .Name}} + switch t { + {{range .Kinds -}} + {{if isOrd . 
-}} + case {{reflectKind .}}: + at := a.{{sliceOf .}} + bt := b.{{sliceOf .}} + switch { + case as && bs : + Vec{{$name}}{{short .}}(at, bt) + case as && !bs: + {{$name}}IterSV{{short .}}(at[0], bt, bit) + case !as && bs: + {{$name}}IterVS{{short .}}(at, bt[0], ait) + default: + Vec{{$name}}Iter{{short .}}(at, bt, ait, bit) + } + return + {{end -}} + {{end -}} + default: + return errors.Errorf("Unsupported type %v for {{$name}}", t) + } + ` complexAbs = `{{if eq .Kind.String "complex64" -}} {{else if eq .Kind.String "complex128" -}} @@ -478,7 +554,7 @@ const ( return errors.Wrap(errors.Errorf(typeMismatch, max, maxVal), "Clamp() max") } Clamp{{short .}}(a.{{sliceOf .}}, min, max) - return nil + return nil {{end -}} default: return errors.Errorf("Unsupported type %v for Clamp", t) @@ -553,7 +629,7 @@ const ( if _, ok := err.(NoOpError); ok { err = nil } - return + return {{end -}} default: return nil, errors.Errorf("Unsupported type %v for Arg{{.Name}}", t) @@ -588,6 +664,7 @@ var ( eArithIncr *template.Template eArithIter *template.Template eArithIterIncr *template.Template + eArithRecv *template.Template eMap *template.Template eMapIter *template.Template @@ -597,6 +674,9 @@ var ( eCmpBoolIter *template.Template eCmpSameIter *template.Template + eMinMaxSame *template.Template + eMinMaxIter *template.Template + eReduce *template.Template eReduceFirst *template.Template eReduceLast *template.Template @@ -618,6 +698,7 @@ func init() { eArithIncr = template.Must(template.New("eArithIncr").Funcs(funcs).Parse(eArithIncrRaw)) eArithIter = template.Must(template.New("eArithIter").Funcs(funcs).Parse(eArithIterRaw)) eArithIterIncr = template.Must(template.New("eArithIterIncr").Funcs(funcs).Parse(eArithIterIncrRaw)) + eArithRecv = template.Must(template.New("eArithRecv").Funcs(funcs).Parse(eArithRecvRaw)) eMap = template.Must(template.New("eMap").Funcs(funcs).Parse(eMapRaw)) eMapIter = template.Must(template.New("eMapIter").Funcs(funcs).Parse(eMapIterRaw)) @@ -627,6 +708,9 @@ func init() { eCmpBoolIter = template.Must(template.New("eCmpBoolIter").Funcs(funcs).Parse(eCmpBoolIterRaw)) eCmpSameIter = template.Must(template.New("eCmpSameIter").Funcs(funcs).Parse(eCmpSameIterRaw)) + eMinMaxSame = template.Must(template.New("eMinMaxSame").Funcs(funcs).Parse(eMinMaxSameRaw)) + eMinMaxIter = template.Must(template.New("eMinMaxSameIter").Funcs(funcs).Parse(eMinMaxSameIterRaw)) + eReduce = template.Must(template.New("eReduce").Funcs(funcs).Parse(eReduceRaw)) eReduceFirst = template.Must(template.New("eReduceFirst").Funcs(funcs).Parse(eReduceFirstRaw)) eReduceLast = template.Must(template.New("eReduceLast").Funcs(funcs).Parse(eReduceLastRaw)) diff --git a/genlib2/agg2_body.go b/genlib2/agg2_body.go index 81141fc..6f85f90 100644 --- a/genlib2/agg2_body.go +++ b/genlib2/agg2_body.go @@ -5,21 +5,40 @@ import "text/template" // level 2 aggregation (tensor.StdEng) templates const cmpPrepRaw = `var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts({{.VecVar}}.Shape(), {{.VecVar}}.Dtype(), false, opts...); err != nil{ + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts({{.VecVar}}.Shape(), {{.VecVar}}.Dtype(), {{.VecVar}}.DataOrder(),false, opts...); err != nil{ return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } ` const arithPrepRaw = `var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts({{.VecVar}}.Shape(), {{.VecVar}}.Dtype(), true, opts...); err != nil{ + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts({{.VecVar}}.Shape(), {{.VecVar}}.Dtype(), {{.VecVar}}.DataOrder(), true, opts...); err != nil{ return nil, errors.Wrap(err, "Unable to handle funcOpts") - } + } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } +` + +const minmaxPrepRaw = `var safe bool + var ctx context.Context + if ctx, reuse, safe, _, _, _, err = handleFuncOpts({{.VecVar}}.Shape(), {{.VecVar}}.Dtype(), {{.VecVar}}.DataOrder(), true, opts...); err != nil{ + return nil, errors.Wrap(err, "Unable to handle funcOpts") + } + if err = handleCtx(ctx); err !=nil{ + return nil, err // this err will be noopError{}, no need to wrap. + } ` -const prepVVRaw = `if err = binaryCheck(a, b, {{.TypeClassCheck | lower}}Types); err != nil { +const prepVVRaw = `if err = binaryCheck(a, b, dtype.{{.TypeClassCheck}}); err != nil { + return nil, errors.Wrapf(err, "{{.Name}} failed") } @@ -36,10 +55,14 @@ const prepVVRaw = `if err = binaryCheck(a, b, {{.TypeClassCheck | lower}}Types); } ` -const prepMixedRaw = `if err = unaryCheck(t, {{.TypeClassCheck | lower}}Types); err != nil { +const prepMixedRaw = `if err = unaryCheck(t, dtype.{{.TypeClassCheck}}); err != nil { return nil, errors.Wrapf(err, "{{.Name}} failed") } + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "{{.Name}} failed") + } + var reuse DenseTensor {{template "prep" . -}} @@ -47,15 +70,15 @@ const prepMixedRaw = `if err = unaryCheck(t, {{.TypeClassCheck | lower}}Types); typ := t.Dtype().Type var ait, bit, iit Iterator var dataA, dataB, dataReuse, scalarHeader *storage.Header - var useIter bool + var useIter, newAlloc bool if leftTensor { - if dataA, dataB, dataReuse, ait, iit, useIter, err = prepDataVS(t, s, reuse); err != nil { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.{{.Name}}") } scalarHeader = dataB } else { - if dataA, dataB, dataReuse, bit, iit, useIter, err = prepDataSV(s, t, reuse); err != nil { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.{{.Name}}") } scalarHeader = dataA @@ -63,15 +86,19 @@ const prepMixedRaw = `if err = unaryCheck(t, {{.TypeClassCheck | lower}}Types); ` -const prepUnaryRaw = `if err = unaryCheck(a, {{.TypeClassCheck | lower}}Types); err != nil { +const prepUnaryRaw = `if err = unaryCheck(a, dtype.{{.TypeClassCheck}}); err != nil { err = errors.Wrapf(err, "{{.Name}} failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil{ + return nil, err // this err will be a noopError{}, no need to wrap. 
+ } typ := a.Dtype().Type var ait, rit Iterator @@ -88,7 +115,7 @@ const agg2BodyRaw = `if useIter { case incr: err = e.E.{{.Name}}IterIncr(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse - {{if .VV -}} + {{if .VV -}} case toReuse: storage.CopyIter(typ,dataReuse, dataA, iit, ait) ait.Reset() @@ -114,7 +141,7 @@ const agg2BodyRaw = `if useIter { retVal = a default: {{if .VV -}} - if swap{ + if swap { retVal = b.Clone().(Tensor) }else{ retVal = a.Clone().(Tensor) @@ -124,12 +151,17 @@ const agg2BodyRaw = `if useIter { retVal = a.Clone().(Tensor) if leftTensor { err = e.E.{{.Name}}Iter(typ, retVal.hdr(), dataB, ait, bit) - }else { + } else { err = e.E.{{.Name}}Iter(typ, dataA, retVal.hdr(), ait, bit) } {{end -}} } - {{if not .VV -}}returnHeader(scalarHeader){{end}} + {{if not .VV -}} + if newAlloc{ + freeScalar(scalarHeader.Raw) + } + returnHeader(scalarHeader) + {{end -}} return } switch { @@ -138,8 +170,7 @@ const agg2BodyRaw = `if useIter { retVal = reuse {{if .VV -}} case toReuse: - storage.Copy(typ,dataReuse, dataA) - err = e.E.{{.Name}}(typ, dataReuse, dataB) + err = e.E.{{.Name}}Recv(typ, dataA, dataB, dataReuse) retVal = reuse {{else -}} case toReuse && leftTensor: @@ -149,10 +180,20 @@ const agg2BodyRaw = `if useIter { case toReuse && !leftTensor: storage.Copy(typ, dataReuse, dataB) err = e.E.{{.Name}}(typ, dataA, dataReuse) + {{if not .VV -}} + if t.Shape().IsScalarEquiv() { + storage.Copy(typ, dataReuse, dataA) + } + {{end -}} retVal = reuse {{end -}} case !safe: err = e.E.{{.Name}}(typ, dataA, dataB) + {{if not .VV -}} + if t.Shape().IsScalarEquiv() && !leftTensor { + storage.Copy(typ, dataB, dataA) + } + {{end -}} retVal = a default: {{if .VV -}} @@ -164,14 +205,18 @@ const agg2BodyRaw = `if useIter { err = e.E.{{.Name}}(typ, retVal.hdr(), dataB) {{else -}} retVal = a.Clone().(Tensor) - if leftTensor { - err = e.E.{{.Name}}(typ, retVal.hdr(), dataB) - }else { - err = e.E.{{.Name}}(typ, dataA, retVal.hdr()) + if !leftTensor { + storage.Fill(typ, retVal.hdr(), dataA) } + err = e.E.{{.Name}}(typ, retVal.hdr(), dataB) {{end -}} } - {{if not .VV -}}returnHeader(scalarHeader){{end}} + {{if not .VV -}} + if newAlloc{ + freeScalar(scalarHeader.Raw) + } + returnHeader(scalarHeader) + {{end -}} return ` @@ -195,7 +240,7 @@ const agg2CmpBodyRaw = `// check to see if anything needs to be created reuse = NewDense(Bool, a.Shape().Clone(), WithEngine(e)) dataReuse = reuse.hdr() if useIter{ - iit = IteratorFromDense(reuse) + iit = IteratorFromDense(reuse) } } @@ -229,13 +274,18 @@ const agg2CmpBodyRaw = `// check to see if anything needs to be created err = e.E.{{.Name}}Iter(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse } - {{if not .VV -}}returnHeader(scalarHeader){{end}} + {{if not .VV -}} + if newAlloc{ + freeScalar(scalarHeader.Raw) + } + returnHeader(scalarHeader) + {{end -}} return } {{if not .VV -}} // handle special case where A and B have both len 1 - if dataB.L == 1 && dataB.L == 1 { + if len(dataA.Raw) == int(typ.Size()) && len(dataB.Raw) == int(typ.Size()) { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ,dataReuse,dataA) @@ -247,7 +297,7 @@ const agg2CmpBodyRaw = `// check to see if anything needs to be created err = e.E.{{.Inv}}Same(typ, dataReuse, dataA) retVal = reuse return - } + } } {{end -}} @@ -275,7 +325,118 @@ const agg2CmpBodyRaw = `// check to see if anything needs to be created err = e.E.{{.Name}}(typ, dataA, dataB, dataReuse) retVal = reuse } - {{if not .VV -}}returnHeader(scalarHeader){{end}} + {{if not .VV -}} + if 
newAlloc{ + freeScalar(scalarHeader.Raw) + } + returnHeader(scalarHeader) + {{end -}} + return +` + +const agg2MinMaxBodyRaw = `// check to see if anything needs to be created + if reuse == nil{ + {{if .VV -}} + if swap{ + reuse = NewDense(b.Dtype(), b.Shape().Clone(), WithEngine(e)) + } else{ + reuse = NewDense(a.Dtype(), a.Shape().Clone(), WithEngine(e)) + } + {{else -}} + reuse = NewDense(a.Dtype(), a.Shape().Clone(), WithEngine(e)) + {{end -}} + dataReuse = reuse.hdr() + if useIter{ + iit = IteratorFromDense(reuse) + } + } + + + if useIter { + switch { + case !safe && reuse == nil: + err = e.E.{{.Name}}Iter(typ, dataA, dataB, ait, bit) + retVal = a + {{if .VV -}} + case safe && reuse != nil: + storage.CopyIter(typ,dataReuse,dataA, iit, ait) + ait.Reset() + iit.Reset() + err = e.E.{{.Name}}Iter(typ, dataReuse, dataB, iit, bit) + retVal = reuse + {{else -}} + case safe && reuse != nil && !leftTensor: + storage.CopyIter(typ,dataReuse,dataB, iit, bit) + bit.Reset() + iit.Reset() + err = e.E.{{.Name}}Iter(typ, dataA, dataReuse, ait, bit) + retVal = reuse + case safe && reuse != nil && leftTensor: + storage.CopyIter(typ,dataReuse,dataA, iit, ait) + ait.Reset() + iit.Reset() + err = e.E.{{.Name}}Iter(typ, dataReuse, dataB, iit, bit) + retVal = reuse + {{end -}} + default: // safe && bool + panic("Unreachable") + } + {{if not .VV -}} + if newAlloc{ + freeScalar(scalarHeader.Raw) + } + returnHeader(scalarHeader) + {{end -}} + return + } + + {{if not .VV -}} + // handle special case where A and B have both len 1 + if len(dataA.Raw) == int(typ.Size()) && len(dataB.Raw) == int(typ.Size()) { + switch { + case safe && reuse != nil && leftTensor: + storage.Copy(typ,dataReuse,dataA) + err = e.E.{{.Name}}(typ, dataReuse, dataB) + retVal = reuse + return + case safe && reuse != nil && !leftTensor: + storage.Copy(typ,dataReuse,dataB) + err = e.E.{{.Name}}(typ, dataReuse, dataA) + retVal = reuse + return + } + } + {{end -}} + + // standard + switch { + case !safe && reuse == nil: + err = e.E.{{.Name}}(typ, dataA, dataB) + retVal = a + {{if .VV -}} + case safe && reuse != nil: + storage.Copy(typ,dataReuse,dataA) + err = e.E.{{.Name}}(typ, dataReuse, dataB) + retVal = reuse + {{else -}} + case safe && reuse != nil && leftTensor: + storage.Copy(typ,dataReuse,dataA) + err = e.E.{{.Name}}(typ, dataReuse, dataB) + retVal = reuse + case safe && reuse != nil && !leftTensor: + storage.Copy(typ,dataReuse,dataB) + err = e.E.{{.Name}}(typ, dataA, dataReuse) + retVal = reuse + {{end -}} + default: + panic("Unreachable") + } + {{if not .VV -}} + if newAlloc{ + freeScalar(scalarHeader.Raw) + } + returnHeader(scalarHeader) + {{end -}} return ` @@ -329,12 +490,13 @@ const agg2UnaryBodyRaw = ` ` var ( - prepVV *template.Template - prepMixed *template.Template - prepUnary *template.Template - agg2Body *template.Template - agg2CmpBody *template.Template - agg2UnaryBody *template.Template + prepVV *template.Template + prepMixed *template.Template + prepUnary *template.Template + agg2Body *template.Template + agg2CmpBody *template.Template + agg2UnaryBody *template.Template + agg2MinMaxBody *template.Template ) func init() { @@ -344,4 +506,5 @@ func init() { agg2Body = template.Must(template.New("agg2body").Funcs(funcs).Parse(agg2BodyRaw)) agg2CmpBody = template.Must(template.New("agg2CmpBody").Funcs(funcs).Parse(agg2CmpBodyRaw)) agg2UnaryBody = template.Must(template.New("agg2UnaryBody").Funcs(funcs).Parse(agg2UnaryBodyRaw)) + agg2MinMaxBody = template.Must(template.New("agg2MinMaxBody").Funcs(funcs).Parse(agg2MinMaxBodyRaw)) 
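// Illustrative sketch (not part of the diff): the agg2MinMaxBody template above generates the
// engine-level MinBetween/MaxBetween implementations, which conceptually perform an
// elementwise min/max between two same-shaped tensors under the usual safe/reuse/unsafe
// FuncOpt semantics. Only the MinBetweener interface signature is taken from this diff;
// which concrete engines implement it is left open here.
func exampleMinBetween(e MinBetweener) {
	a := New(WithShape(2, 2), WithBacking([]float64{1, 9, 3, 7}))
	b := New(WithShape(2, 2), WithBacking([]float64{4, 2, 8, 5}))

	min, _ := e.MinBetween(a, b) // elementwise result: [1 2 3 5]
	_ = min
}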
} diff --git a/genlib2/agg3_body.go b/genlib2/agg3_body.go index 0eb274c..024204a 100644 --- a/genlib2/agg3_body.go +++ b/genlib2/agg3_body.go @@ -66,17 +66,22 @@ const denseIdentityArithTestBodyRaw = `iden := func(a *Dense) bool { _, ok := a.Engine().({{interfaceName .Name}}); we = we || !ok {{template "call0" . }} + {{if eq .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} if err, retEarly := qcErrCheck(t, "{{.Name}}", a, b, we, err); retEarly{ if err != nil { return false } return true } - + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { return false } - {{template "funcoptcheck" -}} + {{if ne .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} return true } @@ -96,6 +101,9 @@ const denseIdentityArithScalarTestRaw = `iden1 := func(q *Dense) bool { _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok {{template "call0" . }} + {{if eq .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} if err, retEarly := qcErrCheck(t, "{{.Name}}", a, b, we, err); retEarly{ if err != nil { return false @@ -106,7 +114,9 @@ const denseIdentityArithScalarTestRaw = `iden1 := func(q *Dense) bool { if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { return false } - {{template "funcoptcheck" -}} + {{if ne .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} return true } @@ -127,6 +137,9 @@ iden2 := func(q *Dense) bool { _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok {{template "call1" . }} + {{if eq .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} if err, retEarly := qcErrCheck(t, "{{.Name}}", a, b, we, err); retEarly{ if err != nil { return false @@ -137,7 +150,9 @@ iden2 := func(q *Dense) bool { if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { return false } - {{template "funcoptcheck" -}} + {{if ne .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} return true } @@ -160,6 +175,9 @@ const denseInvArithTestBodyRaw = `inv := func(a *Dense) bool { _, ok := a.Engine().({{interfaceName .Name}}); we = we || !ok {{template "call0" . }} + {{if eq .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} if err, retEarly := qcErrCheck(t, "{{.Name}}", a, b, we, err); retEarly{ if err != nil { return false @@ -167,11 +185,14 @@ const denseInvArithTestBodyRaw = `inv := func(a *Dense) bool { return true } {{template "callInv" .}} - + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { return false } - {{template "funcoptcheck" -}} + {{if ne .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} + return true } @@ -191,6 +212,9 @@ const denseInvArithScalarTestRaw = `inv1 := func(q *Dense) bool { _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok {{template "call0" . }} + {{if eq .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} if err, retEarly := qcErrCheck(t, "{{.Name}}VS", a, b, we, err); retEarly{ if err != nil { return false @@ -202,7 +226,9 @@ const denseInvArithScalarTestRaw = `inv1 := func(q *Dense) bool { if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { return false } - {{template "funcoptcheck" -}} + {{if ne .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} return true } @@ -224,6 +250,9 @@ inv2 := func(q *Dense) bool { _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok {{template "call1" . 
}} + {{if eq .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} if err, retEarly := qcErrCheck(t, "{{.Name}}SV", a, b, we, err); retEarly{ if err != nil { return false @@ -235,7 +264,9 @@ inv2 := func(q *Dense) bool { if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { return false } - {{template "funcoptcheck" -}} + {{if ne .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} return true } @@ -246,12 +277,97 @@ if err := quick.Check(inv2, &quick.Config{Rand: newRand(), MaxCount: quickchecks {{end -}} ` +const denseArithScalarWrongTypeTestRaw = `type Foo int +wt1 := func(a *Dense) bool{ + b := Foo(0) + {{template "call0" .}} + if err == nil { + return false + } + _ = ret + return true +} +if err := quick.Check(wt1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for {{.Name}} (tensor as left, scalar as right) failed: %v", err) +} + +wt2 := func(a *Dense) bool{ + b := Foo(0) + {{template "call1" .}} + if err == nil { + return false + } + _ = ret + return true +} +if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for {{.Name}} (tensor as right, scalar as left) failed: %v", err) +} +` + +const denseArithReuseMutationTestRaw = `mut := func(a, b *Dense, reuseA bool) bool { + // req because we're only testing on one kind of tensor/engine combo + a.e = StdEng{} + a.oe = StdEng{} + a.flag = 0 + b.e = StdEng{} + b.oe = StdEng{} + b.flag = 0 + + if a.Dtype() != b.Dtype(){ + return true + } + if !a.Shape().Eq(b.Shape()){ + return true + } + + + + {{template "callVanilla" .}} + we, willFailEq := willerr(a, {{.TypeClassName}}, {{.EqFailTypeClassName}}) + _, ok := a.Engine().({{interfaceName .Name}}); we = we || !ok + + + + var ret, reuse {{template "retType" .}} + if reuseA { + {{template "call0" .}}, WithReuse(a)) + reuse = a + } else { + {{template "call0" .}}, WithReuse(b)) + reuse = b + } + + + if err, retEarly := qcErrCheck(t, "{{.Name}}", a, b, we, err); retEarly{ + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + + {{template "funcoptcheck" -}} + + return true +} +if err := quick.Check(mut, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Reuse Mutation test for {{.Name}} failed: %v", err) +} + +` + var ( denseArithBody *template.Template denseArithScalarBody *template.Template denseIdentityArithTest *template.Template denseIdentityArithScalarTest *template.Template + + denseArithScalarWrongTypeTest *template.Template ) func init() { @@ -260,4 +376,6 @@ func init() { denseIdentityArithTest = template.Must(template.New("dense identity test").Funcs(funcs).Parse(denseIdentityArithTestBodyRaw)) denseIdentityArithScalarTest = template.Must(template.New("dense scalar identity test").Funcs(funcs).Parse(denseIdentityArithScalarTestRaw)) + + denseArithScalarWrongTypeTest = template.Must(template.New("dense scalar wrongtype test").Funcs(funcs).Parse(denseArithScalarWrongTypeTestRaw)) } diff --git a/genlib2/arith_tests.go b/genlib2/arith_tests.go index e53f859..77b1647 100644 --- a/genlib2/arith_tests.go +++ b/genlib2/arith_tests.go @@ -7,21 +7,28 @@ import ( ) const ( - APICallVVRaw = `ret, err := {{.Name}}(a, b {{template "funcoptuse"}})` - APICallVSRaw = `ret, err := {{.Name}}(a, b {{template "funcoptuse"}})` - APICallSVRaw = `ret, err := {{.Name}}(b, a {{template "funcoptuse"}})` + APICallVVxRaw = `correct, 
err := {{.Name}}(a, b)` // no funcopt + APICallVVReuseMutRaw = `ret, err = {{.Name}}(a, b` + APICallVVRaw = `ret, err := {{.Name}}(a, b {{template "funcoptuse"}})` + APICallVSRaw = `ret, err := {{.Name}}(a, b {{template "funcoptuse"}})` + APICallSVRaw = `ret, err := {{.Name}}(b, a {{template "funcoptuse"}})` APIInvVVRaw = `ret, err = {{.Inv}}(ret, b, UseUnsafe())` APIInvVSRaw = `ret, err = {{.Inv}}(ret, b, UseUnsafe())` APIInvSVRaw = `ret, err = {{.Name}}(b, ret, UseUnsafe())` - DenseMethodCallVVRaw = `ret, err := a.{{.Name}}(b {{template "funcoptuse"}})` - DenseMethodCallVSRaw = `ret, err := a.{{.Name}}Scalar(b, true {{template "funcoptuse"}})` - DenseMethodCallSVRaw = `ret, err := a.{{.Name}}Scalar(b, false {{template "funcoptuse"}})` + DenseMethodCallVVxRaw = `correct, err := a.{{.Name}}(b)` // no funcopt + DenseMethodCallVVReuseMutRaw = `ret, err = a.{{.Name}}(b` + DenseMethodCallVVRaw = `ret, err := a.{{.Name}}(b {{template "funcoptuse"}})` + DenseMethodCallVSRaw = `ret, err := a.{{.Name}}Scalar(b, true {{template "funcoptuse"}})` + DenseMethodCallSVRaw = `ret, err := a.{{.Name}}Scalar(b, false {{template "funcoptuse"}})` DenseMethodInvVVRaw = `ret, err = ret.{{.Inv}}(b, UseUnsafe())` DenseMethodInvVSRaw = `ret, err = ret.{{.Inv}}Scalar(b, true, UseUnsafe())` DenseMethodInvSVRaw = `ret, err = ret.{{.Name}}Scalar(b, false, UseUnsafe())` + + APIRetType = `Tensor` + DenseRetType = `*Dense` ) type ArithTest struct { @@ -64,6 +71,11 @@ func (fn *ArithTest) WriteBody(w io.Writer) { if fn.IsInv { fn.writeInv(w) } + fn.WriteScalarWrongType(w) + + if fn.FuncOpt == "reuse" && fn.arithOp.Name() != "Pow" { + fn.writeReuseMutate(w) + } } func (fn *ArithTest) canWrite() bool { @@ -143,6 +155,41 @@ func (fn *ArithTest) writeInv(w io.Writer) { t.Execute(w, fn) } +func (fn *ArithTest) writeReuseMutate(w io.Writer) { + t := template.Must(template.New("Reuse mutation test").Funcs(funcs).Parse(denseArithReuseMutationTestRaw)) + switch fn.lvl { + case API: + return // tmp + case Dense: + template.Must(t.New("callVanilla").Parse(DenseMethodCallVVxRaw)) + template.Must(t.New("retType").Parse(DenseRetType)) + template.Must(t.New("call0").Parse(DenseMethodCallVVReuseMutRaw)) + + } + template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) + template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) + template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) + t.Execute(w, fn) +} + +func (fn *ArithTest) WriteScalarWrongType(w io.Writer) { + if !fn.scalars { + return + } + if fn.FuncOpt != "" { + return + } + t := template.Must(template.New("dense scalar wrongtype test").Funcs(funcs).Parse(denseArithScalarWrongTypeTestRaw)) + template.Must(t.New("call0").Parse(APICallVSRaw)) + template.Must(t.New("call1").Parse(APICallSVRaw)) + template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) + template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt])) + template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) + template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) + + t.Execute(w, fn) +} + func (fn *ArithTest) Write(w io.Writer) { sig := fn.Signature() w.Write([]byte("func ")) @@ -158,10 +205,10 @@ func generateAPIArithTests(f io.Writer, ak Kinds) { t := &ArithTest{ arithOp: op, lvl: API, - EqFailTypeClassName: "nil", + EqFailTypeClassName: "nilTC", } if t.name == "Pow" { - t.EqFailTypeClassName = "complexTypes" + t.EqFailTypeClassName = "dtype.Complexes" } tests = append(tests, t) } @@ -187,6 +234,13 @@ func generateAPIArithTests(f io.Writer, ak 
Kinds) { fn.FuncOpt = "incr" } + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + fn.FuncOpt = "context" + } + for _, fn := range tests { if fn.canWrite() { fn.Write(f) @@ -201,13 +255,13 @@ func generateAPIArithScalarTests(f io.Writer, ak Kinds) { arithOp: op, scalars: true, lvl: API, - EqFailTypeClassName: "nil", + EqFailTypeClassName: "nilTC", } switch t.name { case "Pow": - t.EqFailTypeClassName = "complexTypes" + t.EqFailTypeClassName = "dtype.Complexes" case "Sub": - t.EqFailTypeClassName = "unsignedTypes" + t.EqFailTypeClassName = "dtype.Unsigned" } tests = append(tests, t) } @@ -233,6 +287,13 @@ func generateAPIArithScalarTests(f io.Writer, ak Kinds) { fn.FuncOpt = "incr" } + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + fn.FuncOpt = "context" + } + for _, fn := range tests { if fn.canWrite() { fn.Write(f) @@ -246,10 +307,10 @@ func generateDenseMethodArithTests(f io.Writer, ak Kinds) { t := &ArithTest{ arithOp: op, lvl: Dense, - EqFailTypeClassName: "nil", + EqFailTypeClassName: "nilTC", } if t.name == "Pow" { - t.EqFailTypeClassName = "complexTypes" + t.EqFailTypeClassName = "dtype.Complexes" } tests = append(tests, t) } @@ -289,13 +350,13 @@ func generateDenseMethodScalarTests(f io.Writer, ak Kinds) { arithOp: op, scalars: true, lvl: Dense, - EqFailTypeClassName: "nil", + EqFailTypeClassName: "nilTC", } switch t.name { case "Pow": - t.EqFailTypeClassName = "complexTypes" + t.EqFailTypeClassName = "dtype.Complexes" case "Sub": - t.EqFailTypeClassName = "unsignedTypes" + t.EqFailTypeClassName = "dtype.Unsigned" } tests = append(tests, t) } diff --git a/genlib2/array_getset.go b/genlib2/array_getset.go index 8fa0702..a75edd5 100644 --- a/genlib2/array_getset.go +++ b/genlib2/array_getset.go @@ -6,7 +6,7 @@ import ( "text/template" ) -const asSliceRaw = `func (h *Header) {{asType . | strip | title}}s() []{{asType .}} { return *(*[]{{asType .}})(unsafe.Pointer(h)) } +const asSliceRaw = `func (h *Header) {{asType . | strip | title}}s() []{{asType .}} {return (*(*[]{{asType .}})(unsafe.Pointer(&h.Raw)))[:h.TypedLen({{short . | unexport}}Type):h.TypedLen({{short . | unexport}}Type)]} ` const setBasicRaw = `func (h *Header) Set{{short . }}(i int, x {{asType . }}) { h.{{sliceOf .}}[i] = x } @@ -23,18 +23,17 @@ func (a *array) Get(i int) interface{} { {{else -}} case reflect.{{reflectKind .}}: return a.{{getOne .}}(i) - {{end -}} + {{end -}}; {{end -}} default: - at := uintptr(a.Ptr) + uintptr(i) * a.t.Size() - val := reflect.NewAt(a.t, unsafe.Pointer(at)) + val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) val = reflect.Indirect(val) return val.Interface() } } ` -const setRaw = `// Set sets the value of the underlying array at the index i. +const setRaw = `// Set sets the value of the underlying array at the index i. 
func (a *array) Set(i int, x interface{}) { switch a.t.Kind() { {{range .Kinds -}} @@ -47,9 +46,7 @@ func (a *array) Set(i int, x interface{}) { {{end -}} default: xv := reflect.ValueOf(x) - ptr := uintptr(a.Ptr) - want := ptr + uintptr(i)*a.t.Size() - val := reflect.NewAt(a.t, unsafe.Pointer(want)) + val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) val = reflect.Indirect(val) val.Set(xv) } @@ -75,12 +72,11 @@ func (a *array) Memset(x interface{}) error { {{end -}} {{end -}} } - + xv := reflect.ValueOf(x) - ptr := uintptr(a.Ptr) - for i := 0; i < a.L; i++ { - want := ptr + uintptr(i)*a.t.Size() - val := reflect.NewAt(a.t, unsafe.Pointer(want)) + l := a.Len() + for i := 0; i < l; i++ { + val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) val = reflect.Indirect(val) val.Set(xv) } @@ -95,16 +91,17 @@ func (a array) Eq(other interface{}) bool { return false } - if oa.L != a.L { + if oa.Len() != a.Len() { return false } - + /* if oa.C != a.C { return false } + */ // same exact thing - if uintptr(oa.Ptr) == uintptr(a.Ptr){ + if uintptr(unsafe.Pointer(&oa.Header.Raw[0])) == uintptr(unsafe.Pointer(&a.Header.Raw[0])){ return true } @@ -121,7 +118,7 @@ func (a array) Eq(other interface{}) bool { {{end -}} {{end -}} default: - for i := 0; i < a.L; i++{ + for i := 0; i < a.Len(); i++{ if !reflect.DeepEqual(a.Get(i), oa.Get(i)){ return false } @@ -179,30 +176,28 @@ const copyArrayIterRaw = `func copyArrayIter(dst, src array, diter, siter Iterat ` const memsetIterRaw = ` -func (t *array) memsetIter(x interface{}, it Iterator) (err error) { +func (a *array) memsetIter(x interface{}, it Iterator) (err error) { var i int - switch t.t{ + switch a.t{ {{range .Kinds -}} {{if isParameterized . -}} {{else -}} case {{reflectKind .}}: xv, ok := x.({{asType .}}) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.{{sliceOf .}} + data := a.{{sliceOf .}} for i, err = it.Next(); err == nil; i, err = it.Next(){ - data[i] = xv + data[i] = xv } err = handleNoOp(err) {{end -}} {{end -}} default: xv := reflect.ValueOf(x) - ptr := uintptr(t.Ptr) for i, err = it.Next(); err == nil; i, err = it.Next(){ - want := ptr + uintptr(i)*t.t.Size() - val := reflect.NewAt(t.t, unsafe.Pointer(want)) + val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) val = reflect.Indirect(val) val.Set(xv) } @@ -213,14 +208,14 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { ` -const zeroIterRaw = `func (t *array) zeroIter(it Iterator) (err error){ +const zeroIterRaw = `func (a *array) zeroIter(it Iterator) (err error){ var i int - switch t.t { + switch a.t { {{range .Kinds -}} {{if isParameterized . 
-}} {{else -}} case {{reflectKind .}}: - data := t.{{sliceOf .}} + data := a.{{sliceOf .}} for i, err = it.Next(); err == nil; i, err = it.Next(){ data[i] = {{if eq .String "bool" -}} false @@ -232,12 +227,10 @@ const zeroIterRaw = `func (t *array) zeroIter(it Iterator) (err error){ {{end -}} {{end -}} default: - ptr := uintptr(t.Ptr) for i, err = it.Next(); err == nil; i, err = it.Next(){ - want := ptr + uintptr(i)*t.t.Size() - val := reflect.NewAt(t.t, unsafe.Pointer(want)) + val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) val = reflect.Indirect(val) - val.Set(reflect.Zero(t.t)) + val.Set(reflect.Zero(a.t)) } err = handleNoOp(err) } @@ -245,16 +238,26 @@ const zeroIterRaw = `func (t *array) zeroIter(it Iterator) (err error){ } ` +const reflectConstTemplateRaw = `var ( + {{range .Kinds -}} + {{if isParameterized . -}} + {{else -}} + {{short . | unexport}}Type = reflect.TypeOf({{asType .}}({{if eq .String "bool" -}} false {{else if eq .String "string" -}}"" {{else if eq .String "unsafe.Pointer" -}}nil {{else -}}0{{end -}})) + {{end -}} + {{end -}} +)` + var ( - AsSlice *template.Template - SimpleSet *template.Template - SimpleGet *template.Template - Get *template.Template - Set *template.Template - Memset *template.Template - MemsetIter *template.Template - Eq *template.Template - ZeroIter *template.Template + AsSlice *template.Template + SimpleSet *template.Template + SimpleGet *template.Template + Get *template.Template + Set *template.Template + Memset *template.Template + MemsetIter *template.Template + Eq *template.Template + ZeroIter *template.Template + ReflectType *template.Template ) func init() { @@ -267,6 +270,7 @@ func init() { MemsetIter = template.Must(template.New("MemsetIter").Funcs(funcs).Parse(memsetIterRaw)) Eq = template.Must(template.New("ArrayEq").Funcs(funcs).Parse(arrayEqRaw)) ZeroIter = template.Must(template.New("Zero").Funcs(funcs).Parse(zeroIterRaw)) + ReflectType = template.Must(template.New("ReflectType").Funcs(funcs).Parse(reflectConstTemplateRaw)) } func generateArrayMethods(f io.Writer, ak Kinds) { @@ -295,3 +299,8 @@ func generateHeaderGetSet(f io.Writer, ak Kinds) { } } } + +func generateReflectTypes(f io.Writer, ak Kinds) { + ReflectType.Execute(f, ak) + fmt.Fprintf(f, "\n\n\n") +} diff --git a/genlib2/cmp_tests.go b/genlib2/cmp_tests.go index e0189e7..1110e6e 100644 --- a/genlib2/cmp_tests.go +++ b/genlib2/cmp_tests.go @@ -1,471 +1,471 @@ -package main - -import ( - "fmt" - "io" - "text/template" -) - -const ( - APICallVVaxbRaw = `axb, err := {{.Name}}(a, b {{template "funcoptuse" . -}})` - APICallVVbxcRaw = `bxc, err := {{.Name}}(b, c {{template "funcoptuse" . -}})` - APICallVVaxcRaw = `axc, err := {{.Name}}(a, c {{template "funcoptuse" . -}})` - APICallVVbxaRaw = `bxa, err := {{.Name}}(b, a {{template "funcoptuse" . -}})` - APICallMixedaxbRaw = `axb, err := {{.Name}}(a, b {{template "funcoptuse" . -}})` - APICallMixedbxcRaw = `bxc, err := {{.Name}}(b, c {{template "funcoptuse" . -}})` - APICallMixedaxcRaw = `axc, err := {{.Name}}(a, c {{template "funcoptuse" . -}})` - APICallMixedbxaRaw = `bxa, err := {{.Name}}(b, a {{template "funcoptuse" . -}})` - - DenseMethodCallVVaxbRaw = `axb, err := a.{{.Name}}(b {{template "funcoptuse" . -}})` - DenseMethodCallVVbxcRaw = `bxc, err := b.{{.Name}}(c {{template "funcoptuse" . -}})` - DenseMethodCallVVaxcRaw = `axc, err := a.{{.Name}}(c {{template "funcoptuse" . -}})` - DenseMethodCallVVbxaRaw = `bxa, err := b.{{.Name}}(a {{template "funcoptuse" . 
-}})` - DenseMethodCallMixedaxbRaw = `axb, err := a.{{.Name}}Scalar(b, true {{template "funcoptuse" . -}})` - DenseMethodCallMixedbxcRaw = `bxc, err := c.{{.Name}}Scalar(b, false {{template "funcoptuse" . -}})` - DenseMethodCallMixedaxcRaw = `axc, err := a.{{.Name}}(c {{template "funcoptuse" . -}})` - DenseMethodCallMixedbxaRaw = `bxa, err := a.{{.Name}}Scalar(b, false {{template "funcoptuse" . -}})` -) - -const transitivityCheckRaw = `{{if eq .FuncOpt "assame" -}} - if !threewayEq(axb.Data(), bxc.Data(), axc.Data()){ - t.Errorf("a: %-v", a) - t.Errorf("b: %-v", b) - t.Errorf("c: %-v", c) - t.Errorf("axb.Data() %v", axb.Data()) - t.Errorf("bxc.Data() %v", bxc.Data()) - t.Errorf("axc.Data() %v", axc.Data()) - return false - } -{{else -}} - {{if eq .Level "API" -}} - ab := axb.(*Dense).Bools() - bc := bxc.(*Dense).Bools() - ac := axc.(*Dense).Bools() - {{else -}} - ab := axb.Bools() - bc := bxc.Bools() - ac := axc.Bools() - {{end -}} - for i, vab := range ab { - if vab && bc[i] { - if !ac[i]{ - return false - } - } - } -{{end -}} -` - -const transitivityBodyRaw = `transFn := func(q *Dense) bool { - we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) - _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok - - {{template "funcoptdecl" . -}} - - r := newRand() - a := q.Clone().(*Dense) - b := q.Clone().(*Dense) - c := q.Clone().(*Dense) - - bv, _ := quick.Value(b.Dtype().Type, r) - cv, _ := quick.Value(c.Dtype().Type, r) - b.Memset(bv.Interface()) - c.Memset(cv.Interface()) - - {{template "axb" .}} - if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ - if err != nil { - return false - } - return true - } - - {{template "bxc" . }} - if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙c", b, c, we, err); retEarly{ - if err != nil { - return false - } - return true - } - - {{template "axc" . }} - if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙c", a, c, we, err); retEarly{ - if err != nil { - return false - } - return true - } - - {{template "transitivityCheck" .}} - return true -} -if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for {{.Name}} failed: %v", err) -} -` - -const transitivityMixedBodyRaw = `transFn := func(q *Dense) bool { - we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) - _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok - - {{template "funcoptdecl" . -}} - - r := newRand() - a := q.Clone().(*Dense) - bv, _ := quick.Value(a.Dtype().Type, r) - b := bv.Interface() - c := q.Clone().(*Dense) - cv, _ := quick.Value(c.Dtype().Type, r) - c.Memset(cv.Interface()) - - {{template "axb" . }} - if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ - if err != nil { - return false - } - return true - } - - {{template "bxc" . }} - if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙c", c, b, we, err); retEarly{ - if err != nil { - return false - } - return true - } - - {{template "axc" . 
}} - if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙c", a, c, we, err); retEarly{ - if err != nil { - return false - } - return true - } - - {{template "transitivityCheck" .}} - return true -} -if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for {{.Name}} failed: %v", err) -} -` - -const symmetryBodyRaw = `symFn := func(q *Dense) bool { - we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) - _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok - - {{template "funcoptdecl" . -}} - - r := newRand() - a := q.Clone().(*Dense) - b := q.Clone().(*Dense) - - bv, _ := quick.Value(b.Dtype().Type, r) - b.Memset(bv.Interface()) - - {{template "axb" .}} - if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ - if err != nil { - return false - } - return true - } - - {{template "bxa" .}} - if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙a", a, b, we, err); retEarly{ - if err != nil { - return false - } - return true - } - return reflect.DeepEqual(axb.Data(), bxa.Data()) - -} -if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for {{.Name}} failed: %v", err) -} -` - -const symmetryMixedBodyRaw = `symFn := func(q *Dense) bool { - we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) - _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok - - {{template "funcoptdecl" . -}} - - r := newRand() - a := q.Clone().(*Dense) - bv, _ := quick.Value(a.Dtype().Type, r) - b := bv.Interface() - - {{template "axb" .}} - if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ - if err != nil { - return false - } - return true - } - - {{template "bxa" .}} - if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙a", a, b, we, err); retEarly{ - if err != nil { - return false - } - return true - } - return reflect.DeepEqual(axb.Data(), bxa.Data()) - -} -if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Symmetry test for {{.Name}} failed: %v", err) -} -` - -type CmpTest struct { - cmpOp - scalars bool - lvl Level - FuncOpt string - EqFailTypeClassName string -} - -func (fn *CmpTest) Name() string { - if fn.cmpOp.Name() == "Eq" || fn.cmpOp.Name() == "Ne" { - return "El" + fn.cmpOp.Name() - } - return fn.cmpOp.Name() -} - -func (fn *CmpTest) Level() string { - switch fn.lvl { - case API: - return "API" - case Dense: - return "Dense" - } - return "" -} - -func (fn *CmpTest) Signature() *Signature { - var name string - switch fn.lvl { - case API: - name = fmt.Sprintf("Test%s", fn.cmpOp.Name()) - case Dense: - name = fmt.Sprintf("TestDense_%s", fn.Name()) - } - if fn.scalars { - name += "Scalar" - } - if fn.FuncOpt != "" { - name += "_" + fn.FuncOpt - } - return &Signature{ - Name: name, - NameTemplate: plainName, - ParamNames: []string{"t"}, - ParamTemplates: []*template.Template{testingType}, - } -} - -func (fn *CmpTest) canWrite() bool { - return fn.IsTransitive || fn.IsSymmetric -} - -func (fn *CmpTest) WriteBody(w io.Writer) { - if fn.IsTransitive { - fn.writeTransitivity(w) - fmt.Fprintf(w, "\n") - } - if fn.IsSymmetric { - fn.writeSymmetry(w) - } -} - -func (fn *CmpTest) writeTransitivity(w io.Writer) { - var t *template.Template - if fn.scalars { - t = template.Must(template.New("dense cmp transitivity test").Funcs(funcs).Parse(transitivityMixedBodyRaw)) - } else { - t = template.Must(template.New("dense cmp transitivity 
test").Funcs(funcs).Parse(transitivityBodyRaw)) - } - - switch fn.lvl { - case API: - if fn.scalars { - template.Must(t.New("axb").Parse(APICallMixedaxbRaw)) - template.Must(t.New("bxc").Parse(APICallMixedbxcRaw)) - template.Must(t.New("axc").Parse(APICallMixedaxcRaw)) - } else { - template.Must(t.New("axb").Parse(APICallVVaxbRaw)) - template.Must(t.New("bxc").Parse(APICallVVbxcRaw)) - template.Must(t.New("axc").Parse(APICallVVaxcRaw)) - } - case Dense: - if fn.scalars { - template.Must(t.New("axb").Parse(DenseMethodCallMixedaxbRaw)) - template.Must(t.New("bxc").Parse(DenseMethodCallMixedbxcRaw)) - template.Must(t.New("axc").Parse(DenseMethodCallMixedaxcRaw)) - } else { - template.Must(t.New("axb").Parse(DenseMethodCallVVaxbRaw)) - template.Must(t.New("bxc").Parse(DenseMethodCallVVbxcRaw)) - template.Must(t.New("axc").Parse(DenseMethodCallVVaxcRaw)) - } - } - template.Must(t.New("transitivityCheck").Parse(transitivityCheckRaw)) - template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) - template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt])) - template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) - template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) - - t.Execute(w, fn) -} - -func (fn *CmpTest) writeSymmetry(w io.Writer) { - var t *template.Template - if fn.scalars { - t = template.Must(template.New("dense cmp symmetry test").Funcs(funcs).Parse(symmetryMixedBodyRaw)) - } else { - t = template.Must(template.New("dense cmp symmetry test").Funcs(funcs).Parse(symmetryBodyRaw)) - } - - switch fn.lvl { - case API: - if fn.scalars { - template.Must(t.New("axb").Parse(APICallMixedaxbRaw)) - template.Must(t.New("bxa").Parse(APICallMixedbxaRaw)) - } else { - template.Must(t.New("axb").Parse(APICallVVaxbRaw)) - template.Must(t.New("bxa").Parse(APICallVVbxaRaw)) - } - case Dense: - if fn.scalars { - template.Must(t.New("axb").Parse(DenseMethodCallMixedaxbRaw)) - template.Must(t.New("bxa").Parse(DenseMethodCallMixedbxaRaw)) - } else { - template.Must(t.New("axb").Parse(DenseMethodCallVVaxbRaw)) - template.Must(t.New("bxa").Parse(DenseMethodCallVVbxaRaw)) - } - } - template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) - template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt])) - template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) - template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) - - t.Execute(w, fn) -} - -func (fn *CmpTest) Write(w io.Writer) { - sig := fn.Signature() - w.Write([]byte("func ")) - sig.Write(w) - w.Write([]byte("{\n")) - fn.WriteBody(w) - w.Write([]byte("}\n")) -} - -func generateAPICmpTests(f io.Writer, ak Kinds) { - var tests []*CmpTest - - for _, op := range cmpBinOps { - t := &CmpTest{ - cmpOp: op, - lvl: API, - EqFailTypeClassName: "nil", - } - tests = append(tests, t) - } - - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - fn.FuncOpt = "assame" - fn.TypeClassName = "nonComplexNumberTypes" - } - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - } - -} - -func generateAPICmpMixedTests(f io.Writer, ak Kinds) { - var tests []*CmpTest - - for _, op := range cmpBinOps { - t := &CmpTest{ - cmpOp: op, - lvl: API, - scalars: true, - EqFailTypeClassName: "nil", - } - tests = append(tests, t) - } - - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - fn.FuncOpt = "assame" - fn.TypeClassName = "nonComplexNumberTypes" - } - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - } -} - -func 
generateDenseMethodCmpTests(f io.Writer, ak Kinds) { - var tests []*CmpTest - - for _, op := range cmpBinOps { - t := &CmpTest{ - cmpOp: op, - lvl: Dense, - EqFailTypeClassName: "nil", - } - tests = append(tests, t) - } - - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - fn.FuncOpt = "assame" - fn.TypeClassName = "nonComplexNumberTypes" - } - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - } -} - -func generateDenseMethodCmpMixedTests(f io.Writer, ak Kinds) { - var tests []*CmpTest - - for _, op := range cmpBinOps { - t := &CmpTest{ - cmpOp: op, - lvl: Dense, - scalars: true, - EqFailTypeClassName: "nil", - } - tests = append(tests, t) - } - - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - fn.FuncOpt = "assame" - fn.TypeClassName = "nonComplexNumberTypes" - } - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - } -} +package main + +import ( + "fmt" + "io" + "text/template" +) + +const ( + APICallVVaxbRaw = `axb, err := {{.Name}}(a, b {{template "funcoptuse" . -}})` + APICallVVbxcRaw = `bxc, err := {{.Name}}(b, c {{template "funcoptuse" . -}})` + APICallVVaxcRaw = `axc, err := {{.Name}}(a, c {{template "funcoptuse" . -}})` + APICallVVbxaRaw = `bxa, err := {{.Name}}(b, a {{template "funcoptuse" . -}})` + APICallMixedaxbRaw = `axb, err := {{.Name}}(a, b {{template "funcoptuse" . -}})` + APICallMixedbxcRaw = `bxc, err := {{.Name}}(b, c {{template "funcoptuse" . -}})` + APICallMixedaxcRaw = `axc, err := {{.Name}}(a, c {{template "funcoptuse" . -}})` + APICallMixedbxaRaw = `bxa, err := {{.Name}}(b, a {{template "funcoptuse" . -}})` + + DenseMethodCallVVaxbRaw = `axb, err := a.{{.Name}}(b {{template "funcoptuse" . -}})` + DenseMethodCallVVbxcRaw = `bxc, err := b.{{.Name}}(c {{template "funcoptuse" . -}})` + DenseMethodCallVVaxcRaw = `axc, err := a.{{.Name}}(c {{template "funcoptuse" . -}})` + DenseMethodCallVVbxaRaw = `bxa, err := b.{{.Name}}(a {{template "funcoptuse" . -}})` + DenseMethodCallMixedaxbRaw = `axb, err := a.{{.Name}}Scalar(b, true {{template "funcoptuse" . -}})` + DenseMethodCallMixedbxcRaw = `bxc, err := c.{{.Name}}Scalar(b, false {{template "funcoptuse" . -}})` + DenseMethodCallMixedaxcRaw = `axc, err := a.{{.Name}}(c {{template "funcoptuse" . -}})` + DenseMethodCallMixedbxaRaw = `bxa, err := a.{{.Name}}Scalar(b, false {{template "funcoptuse" . -}})` +) + +const transitivityCheckRaw = `{{if eq .FuncOpt "assame" -}} + if !threewayEq(axb.Data(), bxc.Data(), axc.Data()){ + t.Errorf("a: %-v", a) + t.Errorf("b: %-v", b) + t.Errorf("c: %-v", c) + t.Errorf("axb.Data() %v", axb.Data()) + t.Errorf("bxc.Data() %v", bxc.Data()) + t.Errorf("axc.Data() %v", axc.Data()) + return false + } +{{else -}} + {{if eq .Level "API" -}} + ab := axb.(*Dense).Bools() + bc := bxc.(*Dense).Bools() + ac := axc.(*Dense).Bools() + {{else -}} + ab := axb.Bools() + bc := bxc.Bools() + ac := axc.Bools() + {{end -}} + for i, vab := range ab { + if vab && bc[i] { + if !ac[i]{ + return false + } + } + } +{{end -}} +` + +const transitivityBodyRaw = `transFn := func(q *Dense) bool { + we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) + _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok + + {{template "funcoptdecl" . 
-}} + + r := newRand() + a := q.Clone().(*Dense) + b := q.Clone().(*Dense) + c := q.Clone().(*Dense) + + bv, _ := quick.Value(b.Dtype().Type, r) + cv, _ := quick.Value(c.Dtype().Type, r) + b.Memset(bv.Interface()) + c.Memset(cv.Interface()) + + {{template "axb" .}} + if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ + if err != nil { + return false + } + return true + } + + {{template "bxc" . }} + if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙c", b, c, we, err); retEarly{ + if err != nil { + return false + } + return true + } + + {{template "axc" . }} + if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙c", a, c, we, err); retEarly{ + if err != nil { + return false + } + return true + } + + {{template "transitivityCheck" .}} + return true +} +if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Transitivity test for {{.Name}} failed: %v", err) +} +` + +const transitivityMixedBodyRaw = `transFn := func(q *Dense) bool { + we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) + _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok + + {{template "funcoptdecl" . -}} + + r := newRand() + a := q.Clone().(*Dense) + bv, _ := quick.Value(a.Dtype().Type, r) + b := bv.Interface() + c := q.Clone().(*Dense) + cv, _ := quick.Value(c.Dtype().Type, r) + c.Memset(cv.Interface()) + + {{template "axb" . }} + if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ + if err != nil { + return false + } + return true + } + + {{template "bxc" . }} + if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙c", c, b, we, err); retEarly{ + if err != nil { + return false + } + return true + } + + {{template "axc" . }} + if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙c", a, c, we, err); retEarly{ + if err != nil { + return false + } + return true + } + + {{template "transitivityCheck" .}} + return true +} +if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Transitivity test for {{.Name}} failed: %v", err) +} +` + +const symmetryBodyRaw = `symFn := func(q *Dense) bool { + we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) + _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok + + {{template "funcoptdecl" . -}} + + r := newRand() + a := q.Clone().(*Dense) + b := q.Clone().(*Dense) + + bv, _ := quick.Value(b.Dtype().Type, r) + b.Memset(bv.Interface()) + + {{template "axb" .}} + if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ + if err != nil { + return false + } + return true + } + + {{template "bxa" .}} + if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙a", a, b, we, err); retEarly{ + if err != nil { + return false + } + return true + } + return reflect.DeepEqual(axb.Data(), bxa.Data()) + +} +if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Transitivity test for {{.Name}} failed: %v", err) +} +` + +const symmetryMixedBodyRaw = `symFn := func(q *Dense) bool { + we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) + _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok + + {{template "funcoptdecl" . 
-}} + + r := newRand() + a := q.Clone().(*Dense) + bv, _ := quick.Value(a.Dtype().Type, r) + b := bv.Interface() + + {{template "axb" .}} + if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ + if err != nil { + return false + } + return true + } + + {{template "bxa" .}} + if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙a", a, b, we, err); retEarly{ + if err != nil { + return false + } + return true + } + return reflect.DeepEqual(axb.Data(), bxa.Data()) + +} +if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Symmetry test for {{.Name}} failed: %v", err) +} +` + +type CmpTest struct { + cmpOp + scalars bool + lvl Level + FuncOpt string + EqFailTypeClassName string +} + +func (fn *CmpTest) Name() string { + if fn.cmpOp.Name() == "Eq" || fn.cmpOp.Name() == "Ne" { + return "El" + fn.cmpOp.Name() + } + return fn.cmpOp.Name() +} + +func (fn *CmpTest) Level() string { + switch fn.lvl { + case API: + return "API" + case Dense: + return "Dense" + } + return "" +} + +func (fn *CmpTest) Signature() *Signature { + var name string + switch fn.lvl { + case API: + name = fmt.Sprintf("Test%s", fn.cmpOp.Name()) + case Dense: + name = fmt.Sprintf("TestDense_%s", fn.Name()) + } + if fn.scalars { + name += "Scalar" + } + if fn.FuncOpt != "" { + name += "_" + fn.FuncOpt + } + return &Signature{ + Name: name, + NameTemplate: plainName, + ParamNames: []string{"t"}, + ParamTemplates: []*template.Template{testingType}, + } +} + +func (fn *CmpTest) canWrite() bool { + return fn.IsTransitive || fn.IsSymmetric +} + +func (fn *CmpTest) WriteBody(w io.Writer) { + if fn.IsTransitive { + fn.writeTransitivity(w) + fmt.Fprintf(w, "\n") + } + if fn.IsSymmetric { + fn.writeSymmetry(w) + } +} + +func (fn *CmpTest) writeTransitivity(w io.Writer) { + var t *template.Template + if fn.scalars { + t = template.Must(template.New("dense cmp transitivity test").Funcs(funcs).Parse(transitivityMixedBodyRaw)) + } else { + t = template.Must(template.New("dense cmp transitivity test").Funcs(funcs).Parse(transitivityBodyRaw)) + } + + switch fn.lvl { + case API: + if fn.scalars { + template.Must(t.New("axb").Parse(APICallMixedaxbRaw)) + template.Must(t.New("bxc").Parse(APICallMixedbxcRaw)) + template.Must(t.New("axc").Parse(APICallMixedaxcRaw)) + } else { + template.Must(t.New("axb").Parse(APICallVVaxbRaw)) + template.Must(t.New("bxc").Parse(APICallVVbxcRaw)) + template.Must(t.New("axc").Parse(APICallVVaxcRaw)) + } + case Dense: + if fn.scalars { + template.Must(t.New("axb").Parse(DenseMethodCallMixedaxbRaw)) + template.Must(t.New("bxc").Parse(DenseMethodCallMixedbxcRaw)) + template.Must(t.New("axc").Parse(DenseMethodCallMixedaxcRaw)) + } else { + template.Must(t.New("axb").Parse(DenseMethodCallVVaxbRaw)) + template.Must(t.New("bxc").Parse(DenseMethodCallVVbxcRaw)) + template.Must(t.New("axc").Parse(DenseMethodCallVVaxcRaw)) + } + } + template.Must(t.New("transitivityCheck").Parse(transitivityCheckRaw)) + template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) + template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt])) + template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) + template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) + + t.Execute(w, fn) +} + +func (fn *CmpTest) writeSymmetry(w io.Writer) { + var t *template.Template + if fn.scalars { + t = template.Must(template.New("dense cmp symmetry test").Funcs(funcs).Parse(symmetryMixedBodyRaw)) + } else { + t = template.Must(template.New("dense cmp 
symmetry test").Funcs(funcs).Parse(symmetryBodyRaw)) + } + + switch fn.lvl { + case API: + if fn.scalars { + template.Must(t.New("axb").Parse(APICallMixedaxbRaw)) + template.Must(t.New("bxa").Parse(APICallMixedbxaRaw)) + } else { + template.Must(t.New("axb").Parse(APICallVVaxbRaw)) + template.Must(t.New("bxa").Parse(APICallVVbxaRaw)) + } + case Dense: + if fn.scalars { + template.Must(t.New("axb").Parse(DenseMethodCallMixedaxbRaw)) + template.Must(t.New("bxa").Parse(DenseMethodCallMixedbxaRaw)) + } else { + template.Must(t.New("axb").Parse(DenseMethodCallVVaxbRaw)) + template.Must(t.New("bxa").Parse(DenseMethodCallVVbxaRaw)) + } + } + template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) + template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt])) + template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) + template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) + + t.Execute(w, fn) +} + +func (fn *CmpTest) Write(w io.Writer) { + sig := fn.Signature() + w.Write([]byte("func ")) + sig.Write(w) + w.Write([]byte("{\n")) + fn.WriteBody(w) + w.Write([]byte("}\n")) +} + +func generateAPICmpTests(f io.Writer, ak Kinds) { + var tests []*CmpTest + + for _, op := range cmpBinOps { + t := &CmpTest{ + cmpOp: op, + lvl: API, + EqFailTypeClassName: "nilTC", + } + tests = append(tests, t) + } + + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + fn.FuncOpt = "assame" + fn.TypeClassName = "dtype.NonComplexNumber" + } + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + } + +} + +func generateAPICmpMixedTests(f io.Writer, ak Kinds) { + var tests []*CmpTest + + for _, op := range cmpBinOps { + t := &CmpTest{ + cmpOp: op, + lvl: API, + scalars: true, + EqFailTypeClassName: "nilTC", + } + tests = append(tests, t) + } + + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + fn.FuncOpt = "assame" + fn.TypeClassName = "dtype.NonComplexNumber" + } + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + } +} + +func generateDenseMethodCmpTests(f io.Writer, ak Kinds) { + var tests []*CmpTest + + for _, op := range cmpBinOps { + t := &CmpTest{ + cmpOp: op, + lvl: Dense, + EqFailTypeClassName: "nilTC", + } + tests = append(tests, t) + } + + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + fn.FuncOpt = "assame" + fn.TypeClassName = "dtype.NonComplexNumber" + } + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + } +} + +func generateDenseMethodCmpMixedTests(f io.Writer, ak Kinds) { + var tests []*CmpTest + + for _, op := range cmpBinOps { + t := &CmpTest{ + cmpOp: op, + lvl: Dense, + scalars: true, + EqFailTypeClassName: "nilTC", + } + tests = append(tests, t) + } + + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + fn.FuncOpt = "assame" + fn.TypeClassName = "dtype.NonComplexNumber" + } + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + } +} diff --git a/genlib2/declarations.go b/genlib2/declarations.go index 970f4ae..9d70630 100644 --- a/genlib2/declarations.go +++ b/genlib2/declarations.go @@ -25,7 +25,7 @@ var cmpSymbolTemplates = [...]string{ } var nonFloatConditionalUnarySymbolTemplates = [...]string{ - `{{if isFloat .Kind -}} + `{{if isFloat .Kind -}} {{.Range}}[{{.Index0}}] = {{mathPkg .Kind}}Abs({{.Range}}[{{.Index0}}]) {{else -}} if {{.Range}}[{{.Index0}}] < 0 { {{.Range}}[{{.Index0}}] = -{{.Range}}[{{.Index0}}] @@ -57,10 +57,11 @@ var unconditionalFloatUnarySymbolTemplates = [...]string{ } var funcOptUse = map[string]string{ 
- "reuse": ",WithReuse(reuse)", - "incr": ",WithIncr(incr)", - "unsafe": ",UseUnsafe()", - "assame": ", AsSameType()", + "reuse": ",WithReuse(reuse)", + "incr": ",WithIncr(incr)", + "unsafe": ",UseUnsafe()", + "assame": ", AsSameType()", + "context": ", WithContext(ctx)", } var funcOptCheck = map[string]string{ @@ -77,7 +78,10 @@ var funcOptCheck = map[string]string{ t.Errorf("Expected ret to be the same as a") return false } - + `, + "context": `if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } `, } @@ -85,10 +89,21 @@ var funcOptDecl = map[string]string{ "reuse": "reuse := New(Of(a.t), WithShape(a.Shape().Clone()...))\n", "incr": "incr := New(Of(a.t), WithShape(a.Shape().Clone()...))\n", "unsafe": "", - "assame": `if err := typeclassCheck(q.Dtype(), {{.TypeClassName}}); err != nil { + "assame": `if err := dtype.TypeClassCheck(q.Dtype(), {{.TypeClassName}}); err != nil { return true // we exit early if the generated type is not something we can handle } `, + "context": `rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1 * time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r * 100)*time.Second) + } + defer cancel() +`, } var funcOptCorrect = map[string]string{ @@ -96,7 +111,8 @@ var funcOptCorrect = map[string]string{ "incr": `incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) `, - "unsafe": "", + "unsafe": "", + "context": "", } var stdTypes = [...]string{ @@ -120,6 +136,27 @@ var stdTypes = [...]string{ "UnsafePointer", } +var arrowBinaryTypes = []string{ + "String", +} + +var arrowFixedWidthTypes = []string{ + "Boolean", +} + +var arrowPrimitiveTypes = []string{ + "Int8", + "Int16", + "Int32", + "Int64", + "Uint8", + "Uint16", + "Uint32", + "Uint64", + "Float32", + "Float64", +} + var parameterizedKinds = [...]reflect.Kind{ reflect.Array, reflect.Chan, @@ -130,6 +167,7 @@ var parameterizedKinds = [...]reflect.Kind{ reflect.Slice, reflect.Struct, } + var number = [...]reflect.Kind{ reflect.Int, reflect.Int8, @@ -405,51 +443,51 @@ func init() { // ops arithBinOps = []arithOp{ - {basicBinOp{"", "Add", false, isAddable}, "numberTypes", true, 0, false, "", true, false}, - {basicBinOp{"", "Sub", false, isNumber}, "numberTypes", false, 0, true, "Add", false, true}, - {basicBinOp{"", "Mul", false, isNumber}, "numberTypes", true, 1, false, "", true, false}, - {basicBinOp{"", "Div", false, isNumber}, "numberTypes", false, 1, true, "Mul", false, false}, - {basicBinOp{"", "Pow", true, isFloatCmplx}, "floatcmplxTypes", true, 1, false, "", false, false}, - {basicBinOp{"", "Mod", false, isNonComplexNumber}, "nonComplexNumberTypes", false, 0, false, "", false, false}, + {basicBinOp{"", "Add", false, isAddable}, "dtype.Number", true, 0, false, "", true, false}, + {basicBinOp{"", "Sub", false, isNumber}, "dtype.Number", false, 0, true, "Add", false, true}, + {basicBinOp{"", "Mul", false, isNumber}, "dtype.Number", true, 1, false, "", true, false}, + {basicBinOp{"", "Div", false, isNumber}, "dtype.Number", false, 1, true, "Mul", false, false}, + {basicBinOp{"", "Pow", true, isFloatCmplx}, "dtype.FloatComplex", true, 1, false, "", false, false}, + {basicBinOp{"", "Mod", false, isNonComplexNumber}, "dtype.NonComplexNumber", false, 0, false, "", false, false}, } for i := range arithBinOps { arithBinOps[i].symbol = arithSymbolTemplates[i] } cmpBinOps = []cmpOp{ - {basicBinOp{"", "Gt", false, isOrd}, 
"ordTypes", "Lt", true, false}, - {basicBinOp{"", "Gte", false, isOrd}, "ordTypes", "Lte", true, false}, - {basicBinOp{"", "Lt", false, isOrd}, "ordTypes", "Gt", true, false}, - {basicBinOp{"", "Lte", false, isOrd}, "ordTypes", "Gte", true, false}, - {basicBinOp{"", "Eq", false, isEq}, "eqTypes", "Eq", true, true}, - {basicBinOp{"", "Ne", false, isEq}, "eqTypes", "Ne", false, true}, + {basicBinOp{"", "Gt", false, isOrd}, "dtype.Ord", "Lt", true, false}, + {basicBinOp{"", "Gte", false, isOrd}, "dtype.Ord", "Lte", true, false}, + {basicBinOp{"", "Lt", false, isOrd}, "dtype.Ord", "Gt", true, false}, + {basicBinOp{"", "Lte", false, isOrd}, "dtype.Ord", "Gte", true, false}, + {basicBinOp{"", "Eq", false, isEq}, "dtype.Eq", "Eq", true, true}, + {basicBinOp{"", "Ne", false, isEq}, "dtype.Eq", "Ne", false, true}, } for i := range cmpBinOps { cmpBinOps[i].symbol = cmpSymbolTemplates[i] } conditionalUnaries = []unaryOp{ - {"", "Abs", false, isSignedNumber, "signedTypes", ""}, - {"", "Sign", false, isSignedNumber, "signedTypes", ""}, + {"", "Abs", false, isSignedNumber, "dtype.Signed", ""}, + {"", "Sign", false, isSignedNumber, "dtype.Signed", ""}, } for i := range conditionalUnaries { conditionalUnaries[i].symbol = nonFloatConditionalUnarySymbolTemplates[i] } unconditionalUnaries = []unaryOp{ - {"", "Neg", false, isNumber, "numberTypes", "Neg"}, - {"", "Inv", false, isNumber, "numberTypes", ""}, - {"", "Square", false, isNumber, "numberTypes", "Sqrt"}, - {"", "Cube", false, isNumber, "numberTypes", "Cbrt"}, - - {"", "Exp", true, isFloatCmplx, "floatcmplxTypes", "Log"}, - {"", "Tanh", true, isFloatCmplx, "floatcmplxTypes", ""}, - {"", "Log", true, isFloatCmplx, "floatcmplxTypes", "Exp"}, - {"", "Log2", true, isFloat, "floatTypes", ""}, - {"", "Log10", true, isFloatCmplx, "floatcmplxTypes", ""}, - {"", "Sqrt", true, isFloatCmplx, "floatcmplxTypes", "Square"}, - {"", "Cbrt", true, isFloat, "floatTypes", "Cube"}, - {"", "InvSqrt", true, isFloat, "floatTypes", ""}, // TODO: cmplx requires to much finagling to the template. Come back to it later + {"", "Neg", false, isNumber, "dtype.Number", "Neg"}, + {"", "Inv", false, isNumber, "dtype.Number", ""}, + {"", "Square", false, isNumber, "dtype.Number", "Sqrt"}, + {"", "Cube", false, isNumber, "dtype.Number", "Cbrt"}, + + {"", "Exp", true, isFloatCmplx, "dtype.FloatComplex", "Log"}, + {"", "Tanh", true, isFloatCmplx, "dtype.FloatComplex", ""}, + {"", "Log", true, isFloatCmplx, "dtype.FloatComplex", "Exp"}, + {"", "Log2", true, isFloat, "dtype.Floats", ""}, + {"", "Log10", true, isFloatCmplx, "dtype.FloatComplex", ""}, + {"", "Sqrt", true, isFloatCmplx, "dtype.FloatComplex", "Square"}, + {"", "Cbrt", true, isFloat, "dtype.Floats", "Cube"}, + {"", "InvSqrt", true, isFloat, "dtype.Floats", ""}, // TODO: cmplx requires to much finagling to the template. 
Come back to it later } nonF := len(unconditionalNumUnarySymbolTemplates) for i := range unconditionalNumUnarySymbolTemplates { @@ -460,7 +498,7 @@ func init() { } specialUnaries = []UnaryOp{ - specialUnaryOp{unaryOp{clampBody, "Clamp", false, isNonComplexNumber, "nonComplexNumberTypes", ""}, []string{"min", "max"}}, + specialUnaryOp{unaryOp{clampBody, "Clamp", false, isNonComplexNumber, "dtype.NonComplexNumber", ""}, []string{"min", "max"}}, } // typed operations diff --git a/genlib2/dense_compat.go b/genlib2/dense_compat.go index 2c96721..afb3fda 100644 --- a/genlib2/dense_compat.go +++ b/genlib2/dense_compat.go @@ -5,7 +5,15 @@ import ( "text/template" ) -const conversionsRaw = `func convFromFloat64s(to Dtype, data []float64) interface{} { +const importsArrowRaw = `import ( + arrowArray "github.com/apache/arrow/go/arrow/array" + "github.com/apache/arrow/go/arrow/bitutil" + arrowTensor "github.com/apache/arrow/go/arrow/tensor" + arrow "github.com/apache/arrow/go/arrow" +) +` + +const conversionsRaw = `func convFromFloat64s(to dtype.Dtype, data []float64) interface{} { switch to { {{range .Kinds -}} {{if isNumber . -}} @@ -194,14 +202,12 @@ func FromMat64(m *mat.Dense, opts ...FuncOpt) *Dense { func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { // checks: if !t.IsNativelyAccessible() { - err = errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible") - return + return nil, errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible") } if !t.IsMatrix() { // error - err = errors.Errorf("Cannot convert *Dense to *mat.Dense. Expected number of dimensions: <=2, T has got %d dimensions (Shape: %v)", t.Dims(), t.Shape()) - return + return nil, errors.Errorf("Cannot convert *Dense to *mat.Dense. Expected number of dimensions: <=2, T has got %d dimensions (Shape: %v)", t.Dims(), t.Shape()) } fo := ParseFuncOpts(opts...) @@ -214,13 +220,13 @@ func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { var data []float64 switch { - case t.t == Float64 && toCopy && !t.IsMaterializable(): + case t.t == Float64 && toCopy && !t.RequiresIterator() && t.viewOf == 0: data = make([]float64, t.len()) copy(data, t.Float64s()) - case !t.IsMaterializable(): + case !t.RequiresIterator() && t.viewOf == 0: data = convToFloat64s(t) default: - it := NewFlatIterator(t.AP) + it := newFlatIterator(&t.AP) var next int for next, err = it.Next(); err == nil; next, err = it.Next() { if err = handleNoOp(err); err != nil { @@ -229,25 +235,144 @@ func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { data = append(data, convToFloat64(t.Get(next))) } err = nil - + } retVal = mat.NewDense(r, c, data) return } + + +` + +type ArrowData struct { + BinaryTypes []string + FixedWidthTypes []string + PrimitiveTypes []string +} + +const compatArrowArrayRaw = `// FromArrowArray converts an "arrow/array".Interface into a Tensor of matching DataType. +func FromArrowArray(a arrowArray.Interface) *Dense { + a.Retain() + defer a.Release() + + r := a.Len() + + // TODO(poopoothegorilla): instead of creating bool ValidMask maybe + // bitmapBytes can be used from arrow API + mask := make([]bool, r) + for i := 0; i < r; i++ { + mask[i] = a.IsNull(i) + } + + switch a.DataType() { + {{range .BinaryTypes -}} + case arrow.BinaryTypes.{{.}}: + {{if eq . 
"String" -}} + backing := make([]string, r) + for i := 0; i < r; i++ { + backing[i] = a.(*arrowArray.{{.}}).Value(i) + } + {{else -}} + backing := a.(*arrowArray.{{.}}).{{.}}Values() + {{end -}} + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + {{end -}} + {{range .FixedWidthTypes -}} + case arrow.FixedWidthTypes.{{.}}: + {{if eq . "Boolean" -}} + backing := make([]bool, r) + for i := 0; i < r; i++ { + backing[i] = a.(*arrowArray.{{.}}).Value(i) + } + {{else -}} + backing := a.(*arrowArray.{{.}}).{{.}}Values() + {{end -}} + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + {{end -}} + {{range .PrimitiveTypes -}} + case arrow.PrimitiveTypes.{{.}}: + backing := a.(*arrowArray.{{.}}).{{.}}Values() + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + {{end -}} + default: + panic(fmt.Sprintf("Unsupported Arrow DataType - %v", a.DataType())) + } + + panic("Unreachable") +} +` + +const compatArrowTensorRaw = `// FromArrowTensor converts an "arrow/tensor".Interface into a Tensor of matching DataType. +func FromArrowTensor(a arrowTensor.Interface) *Dense { + a.Retain() + defer a.Release() + + if !a.IsContiguous() { + panic("Non-contiguous data is Unsupported") + } + + var shape []int + for _, val := range a.Shape() { + shape = append(shape, int(val)) + } + + l := a.Len() + validMask := a.Data().Buffers()[0].Bytes() + dataOffset := a.Data().Offset() + mask := make([]bool, l) + for i := 0; i < l; i++ { + mask[i] = len(validMask) != 0 && bitutil.BitIsNotSet(validMask, dataOffset+i) + } + + switch a.DataType() { + {{range .PrimitiveTypes -}} + case arrow.PrimitiveTypes.{{.}}: + backing := a.(*arrowTensor.{{.}}).{{.}}Values() + if a.IsColMajor() { + return New(WithShape(shape...), AsFortran(backing, mask)) + } + + return New(WithShape(shape...), WithBacking(backing, mask)) + {{end -}} + default: + panic(fmt.Sprintf("Unsupported Arrow DataType - %v", a.DataType())) + } + + panic("Unreachable") +} ` var ( - conversions *template.Template - compats *template.Template + importsArrow *template.Template + conversions *template.Template + compats *template.Template + compatsArrowArray *template.Template + compatsArrowTensor *template.Template ) func init() { + importsArrow = template.Must(template.New("imports_arrow").Funcs(funcs).Parse(importsArrowRaw)) conversions = template.Must(template.New("conversions").Funcs(funcs).Parse(conversionsRaw)) compats = template.Must(template.New("compat").Funcs(funcs).Parse(compatRaw)) + compatsArrowArray = template.Must(template.New("compat_arrow_array").Funcs(funcs).Parse(compatArrowArrayRaw)) + compatsArrowTensor = template.Must(template.New("compat_arrow_tensor").Funcs(funcs).Parse(compatArrowTensorRaw)) } func generateDenseCompat(f io.Writer, generic Kinds) { + // NOTE(poopoothegorilla): an alias is needed for the Arrow Array pkg to prevent naming + // collisions + importsArrow.Execute(f, generic) conversions.Execute(f, generic) compats.Execute(f, generic) + arrowData := ArrowData{ + BinaryTypes: arrowBinaryTypes, + FixedWidthTypes: arrowFixedWidthTypes, + PrimitiveTypes: arrowPrimitiveTypes, + } + compatsArrowArray.Execute(f, arrowData) + compatsArrowTensor.Execute(f, arrowData) } diff --git a/genlib2/dense_compat_tests.go b/genlib2/dense_compat_tests.go index d2fd049..334f2a8 100644 --- a/genlib2/dense_compat_tests.go +++ b/genlib2/dense_compat_tests.go @@ -1,107 +1,287 @@ -package main - -import ( - "io" - "text/template" -) - -const compatTestsRaw = `var toMat64Tests = []struct{ - data 
interface{} - sliced interface{} - shape Shape - dt Dtype -}{ - {{range .Kinds -}} - {{if isNumber . -}} - { Range({{asType . | title | strip}}, 0, 6), []{{asType .}}{0,1,3,4}, Shape{2,3}, {{asType . | title | strip}} }, - {{end -}} - {{end -}} -} -func TestToMat64(t *testing.T){ - assert := assert.New(t) - for i, tmt := range toMat64Tests { - T := New(WithBacking(tmt.data), WithShape(tmt.shape...)) - var m *mat.Dense - var err error - if m, err = ToMat64(T); err != nil { - t.Errorf("ToMat basic test %d failed : %v", i, err) - continue - } - conv := anyToFloat64s(tmt.data) - assert.Equal(conv, m.RawMatrix().Data, "i %d from %v", i, tmt.dt) - - if T, err = sliceDense(T, nil, makeRS(0, 2)); err != nil{ - t.Errorf("Slice failed %v", err) - continue - } - if m, err = ToMat64(T); err != nil { - t.Errorf("ToMat of slice test %d failed : %v", i, err) - continue - } - conv = anyToFloat64s(tmt.sliced) - assert.Equal(conv, m.RawMatrix().Data, "sliced test %d from %v", i, tmt.dt) - t.Logf("Done") - - if tmt.dt == Float64 { - T = New(WithBacking(tmt.data), WithShape(tmt.shape...)) - if m, err = ToMat64(T, UseUnsafe()); err != nil { - t.Errorf("ToMat64 unsafe test %d failed: %v", i, err) - } - conv = anyToFloat64s(tmt.data) - assert.Equal(conv, m.RawMatrix().Data, "float64 unsafe i %d from %v", i, tmt.dt) - conv[0] = 1000 - assert.Equal(conv, m.RawMatrix().Data,"float64 unsafe i %d from %v", i, tmt.dt) - conv[0] = 0 // reset for future tests that use the same backing - } - } - // idiocy test - T := New(Of(Float64), WithShape(2,3,4)) - _, err := ToMat64(T) - if err == nil { - t.Error("Expected an error when trying to convert a 3-T to *mat.Dense") - } -} - -func TestFromMat64(t *testing.T){ - assert := assert.New(t) - var m *mat.Dense - var T *Dense - var backing []float64 - - - for i, tmt := range toMat64Tests { - backing = Range(Float64, 0, 6).([]float64) - m = mat.NewDense(2, 3, backing) - T = FromMat64(m) - conv := anyToFloat64s(tmt.data) - assert.Equal(conv, T.Float64s(), "test %d: []float64 from %v", i, tmt.dt) - assert.True(T.Shape().Eq(tmt.shape)) - - T = FromMat64(m, As(tmt.dt)) - assert.Equal(tmt.data, T.Data()) - assert.True(T.Shape().Eq(tmt.shape)) - - if tmt.dt == Float64{ - backing = Range(Float64, 0, 6).([]float64) - m = mat.NewDense(2, 3, backing) - T = FromMat64(m, UseUnsafe()) - assert.Equal(backing, T.Float64s()) - assert.True(T.Shape().Eq(tmt.shape)) - backing[0] = 1000 - assert.Equal(backing, T.Float64s(), "test %d - unsafe float64", i) - } - } -} -` - -var ( - compatTests *template.Template -) - -func init() { - compatTests = template.Must(template.New("testCompat").Funcs(funcs).Parse(compatTestsRaw)) -} - -func generateDenseCompatTests(f io.Writer, generic Kinds) { - compatTests.Execute(f, generic) -} +package main + +import ( + "io" + "text/template" +) + +const compatTestsRaw = `var toMat64Tests = []struct{ + data interface{} + sliced interface{} + shape Shape + dt dtype.Dtype +}{ + {{range .Kinds -}} + {{if isNumber . -}} + { Range({{asType . | title | strip}}, 0, 6), []{{asType .}}{0,1,3,4}, Shape{2,3}, {{asType . 
| title | strip}} }, + {{end -}} + {{end -}} +} +func TestToMat64(t *testing.T){ + assert := assert.New(t) + for i, tmt := range toMat64Tests { + T := New(WithBacking(tmt.data), WithShape(tmt.shape...)) + var m *mat.Dense + var err error + if m, err = ToMat64(T); err != nil { + t.Errorf("ToMat basic test %d failed : %v", i, err) + continue + } + conv := anyToFloat64s(tmt.data) + assert.Equal(conv, m.RawMatrix().Data, "i %d from %v", i, tmt.dt) + + if T, err = sliceDense(T, nil, makeRS(0, 2)); err != nil{ + t.Errorf("Slice failed %v", err) + continue + } + if m, err = ToMat64(T); err != nil { + t.Errorf("ToMat of slice test %d failed : %v", i, err) + continue + } + conv = anyToFloat64s(tmt.sliced) + assert.Equal(conv, m.RawMatrix().Data, "sliced test %d from %v", i, tmt.dt) + t.Logf("Done") + + if tmt.dt == Float64 { + T = New(WithBacking(tmt.data), WithShape(tmt.shape...)) + if m, err = ToMat64(T, UseUnsafe()); err != nil { + t.Errorf("ToMat64 unsafe test %d failed: %v", i, err) + } + conv = anyToFloat64s(tmt.data) + assert.Equal(conv, m.RawMatrix().Data, "float64 unsafe i %d from %v", i, tmt.dt) + conv[0] = 1000 + assert.Equal(conv, m.RawMatrix().Data,"float64 unsafe i %d from %v", i, tmt.dt) + conv[0] = 0 // reset for future tests that use the same backing + } + } + // idiocy test + T := New(Of(Float64), WithShape(2,3,4)) + _, err := ToMat64(T) + if err == nil { + t.Error("Expected an error when trying to convert a 3-T to *mat.Dense") + } +} + +func TestFromMat64(t *testing.T){ + assert := assert.New(t) + var m *mat.Dense + var T *Dense + var backing []float64 + + + for i, tmt := range toMat64Tests { + backing = Range(Float64, 0, 6).([]float64) + m = mat.NewDense(2, 3, backing) + T = FromMat64(m) + conv := anyToFloat64s(tmt.data) + assert.Equal(conv, T.Float64s(), "test %d: []float64 from %v", i, tmt.dt) + assert.True(T.Shape().Eq(tmt.shape)) + + T = FromMat64(m, As(tmt.dt)) + assert.Equal(tmt.data, T.Data()) + assert.True(T.Shape().Eq(tmt.shape)) + + if tmt.dt == Float64{ + backing = Range(Float64, 0, 6).([]float64) + m = mat.NewDense(2, 3, backing) + T = FromMat64(m, UseUnsafe()) + assert.Equal(backing, T.Float64s()) + assert.True(T.Shape().Eq(tmt.shape)) + backing[0] = 1000 + assert.Equal(backing, T.Float64s(), "test %d - unsafe float64", i) + } + } +} +` + +const compatArrowArrayTestsRaw = `var toArrowArrayTests = []struct{ + data interface{} + valid []bool + dt arrow.DataType + shape Shape +}{ + {{range .PrimitiveTypes -}} + { + data: Range({{.}}, 0, 6), + valid: []bool{true, true, true, false, true, true}, + dt: arrow.PrimitiveTypes.{{ . }}, + shape: Shape{6,1}, + }, + {{end -}} +} +func TestFromArrowArray(t *testing.T){ + assert := assert.New(t) + var T *Dense + pool := memory.NewGoAllocator() + + for i, taat := range toArrowArrayTests { + var m arrowArray.Interface + + switch taat.dt { + {{range .BinaryTypes -}} + case arrow.BinaryTypes.{{ . }}: + b := arrowArray.New{{ . }}Builder(pool) + defer b.Release() + b.AppendValues( + {{if eq . "String" -}} + []string{"0", "1", "2", "3", "4", "5"}, + {{else -}} + Range({{ . }}, 0, 6).([]{{lower . }}), + {{end -}} + taat.valid, + ) + m = b.NewArray() + defer m.Release() + {{end -}} + {{range .FixedWidthTypes -}} + case arrow.FixedWidthTypes.{{ . }}: + b := arrowArray.New{{ . }}Builder(pool) + defer b.Release() + b.AppendValues( + {{if eq . "Boolean" -}} + []bool{true, false, true, false, true, false}, + {{else -}} + Range({{ . }}, 0, 6).([]{{lower . 
}}), + {{end -}} + taat.valid, + ) + m = b.NewArray() + defer m.Release() + {{end -}} + {{range .PrimitiveTypes -}} + case arrow.PrimitiveTypes.{{ . }}: + b := arrowArray.New{{ . }}Builder(pool) + defer b.Release() + b.AppendValues( + Range({{ . }}, 0, 6).([]{{lower . }}), + taat.valid, + ) + m = b.NewArray() + defer m.Release() + {{end -}} + default: + t.Errorf("DataType not supported in tests: %v", taat.dt) + } + + T = FromArrowArray(m) + switch taat.dt { + {{range .PrimitiveTypes -}} + case arrow.PrimitiveTypes.{{ . }}: + conv := taat.data.([]{{lower . }}) + assert.Equal(conv, T.{{ . }}s(), "test %d: []{{lower . }} from %v", i, taat.dt) + {{end -}} + default: + t.Errorf("DataType not supported in tests: %v", taat.dt) + } + for i, invalid := range T.Mask() { + assert.Equal(taat.valid[i], !invalid) + } + assert.True(T.Shape().Eq(taat.shape)) + } +} +` + +const compatArrowTensorTestsRaw = `var toArrowTensorTests = []struct{ + rowMajorData interface{} + colMajorData interface{} + rowMajorValid []bool + colMajorValid []bool + dt arrow.DataType + shape Shape +}{ + {{range .PrimitiveTypes -}} + { + rowMajorData: []{{lower .}}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []{{lower .}}{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, + colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, + dt: arrow.PrimitiveTypes.{{ . }}, + shape: Shape{2,5}, + }, + {{end -}} +} +func TestFromArrowTensor(t *testing.T){ + assert := assert.New(t) + var rowMajorT *Dense + var colMajorT *Dense + pool := memory.NewGoAllocator() + + for i, taat := range toArrowTensorTests { + var rowMajorArr arrowArray.Interface + var colMajorArr arrowArray.Interface + var rowMajor arrowTensor.Interface + var colMajor arrowTensor.Interface + + switch taat.dt { + {{range .PrimitiveTypes -}} + case arrow.PrimitiveTypes.{{ . }}: + b := arrowArray.New{{ . }}Builder(pool) + defer b.Release() + b.AppendValues( + []{{lower . }}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + taat.rowMajorValid, + ) + rowMajorArr = b.NewArray() + defer rowMajorArr.Release() + + b.AppendValues( + []{{lower .}}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + taat.rowMajorValid, + ) + colMajorArr = b.NewArray() + defer colMajorArr.Release() + + rowMajor = arrowTensor.New{{.}}(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"}) + defer rowMajor.Release() + colMajor = arrowTensor.New{{.}}(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.{{ . }}SizeBytes), int64(arrow.{{ . 
}}SizeBytes * 2)}, []string{"x", "y"}) + defer colMajor.Release() + {{end -}} + default: + t.Errorf("DataType not supported in tests: %v", taat.dt) + } + + rowMajorT = FromArrowTensor(rowMajor) + colMajorT = FromArrowTensor(colMajor) + + assert.Equal(taat.rowMajorData, rowMajorT.Data(), "test %d: row major %v", i, taat.dt) + assert.Equal(len(taat.rowMajorValid), len(rowMajorT.Mask()), "test %d: row major %v mask length incorrect", i, taat.dt) + for i, invalid := range rowMajorT.Mask() { + assert.Equal(taat.rowMajorValid[i], !invalid, "test %d: row major %v mask value incorrect", i, taat.dt) + } + assert.True(colMajorT.Shape().Eq(taat.shape)) + + assert.Equal(taat.colMajorData, colMajorT.Data(), "test %d: column major %v", i, taat.dt) + assert.Equal(len(taat.colMajorValid), len(colMajorT.Mask()), "test %d: column major %v mask length incorrect", i, taat.dt) + for i, invalid := range colMajorT.Mask() { + assert.Equal(taat.colMajorValid[i], !invalid, "test %d: column major %v mask value incorrect", i, taat.dt) + } + assert.True(rowMajorT.Shape().Eq(taat.shape)) + } +} +` + +var ( + compatTests *template.Template + compatArrowArrayTests *template.Template + compatArrowTensorTests *template.Template +) + +func init() { + compatTests = template.Must(template.New("testCompat").Funcs(funcs).Parse(compatTestsRaw)) + compatArrowArrayTests = template.Must(template.New("testArrowArrayCompat").Funcs(funcs).Parse(compatArrowArrayTestsRaw)) + compatArrowTensorTests = template.Must(template.New("testArrowTensorCompat").Funcs(funcs).Parse(compatArrowTensorTestsRaw)) +} + +func generateDenseCompatTests(f io.Writer, generic Kinds) { + // NOTE(poopoothegorilla): an alias is needed for the Arrow Array pkg to prevent naming + // collisions + importsArrow.Execute(f, generic) + compatTests.Execute(f, generic) + arrowData := ArrowData{ + BinaryTypes: arrowBinaryTypes, + FixedWidthTypes: arrowFixedWidthTypes, + PrimitiveTypes: arrowPrimitiveTypes, + } + compatArrowArrayTests.Execute(f, arrowData) + compatArrowTensorTests.Execute(f, arrowData) +} diff --git a/genlib2/dense_cons.go b/genlib2/dense_cons.go index 031e73a..aa6bab8 100644 --- a/genlib2/dense_cons.go +++ b/genlib2/dense_cons.go @@ -6,7 +6,7 @@ import ( ) const onesRaw = `// Ones creates a *Dense with the provided shape and type -func Ones(dt Dtype, shape ...int) *Dense { +func Ones(dt dtype.Dtype, shape ...int) *Dense { d := recycledDense(dt, shape) switch d.t.Kind() { {{range .Kinds -}} @@ -48,7 +48,7 @@ const Iraw = `// I creates the identity matrix (usually a square) matrix with 1s // ⎢1 0 0 0⎥ // ⎢0 1 0 0⎥ // ⎣0 0 1 0⎦ -func I(dt Dtype, r, c, k int) *Dense{ +func I(dt dtype.Dtype, r, c, k int) *Dense{ ret := New(Of(dt), WithShape(r,c)) i := k if k < 0 { @@ -68,7 +68,7 @@ func I(dt Dtype, r, c, k int) *Dense{ panic(err) } var nexts []int - iter := NewFlatIterator(s.AP) + iter := newFlatIterator(&s.AP) nexts, err = iter.Slice(rs{i, s.Size(), c + 1}) switch s.t.Kind() { diff --git a/genlib2/dense_cons_tests.go b/genlib2/dense_cons_tests.go index 938d6fa..29d1366 100644 --- a/genlib2/dense_cons_tests.go +++ b/genlib2/dense_cons_tests.go @@ -1,85 +1,85 @@ -package main - -import ( - "io" - "text/template" -) - -const onesTestsRaw = `var onesTests = []struct { - of Dtype - shape Shape - correct interface{} -}{ - {{range .Kinds -}} - {{if isNumber . -}} - { {{asType . | title | strip}}, ScalarShape(), {{asType .}}(1)}, - { {{asType . 
| title | strip}}, Shape{2,2}, []{{asType .}}{1,1,1,1}}, - {{end -}} - {{end -}} - {Bool, ScalarShape(), true}, - {Bool, Shape{2,2}, []bool{true, true, true, true}}, -} - -func TestOnes(t *testing.T){ - assert := assert.New(t) - for _, ot := range onesTests{ - T := Ones(ot.of, ot.shape...) - assert.True(ot.shape.Eq(T.Shape())) - assert.Equal(ot.correct, T.Data()) - } -} -` - -const eyeTestsRaw = `// yes, it's a pun on eye tests, stop asking and go see your optometrist -var eyeTests = []struct{ - E Dtype - R, C, K int - - - correct interface{} -}{ - {{range .Kinds -}} - {{if isNumber . -}} - { {{asType . | title | strip}}, 4,4, 0, []{{asType .}}{1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1}}, - { {{asType . | title | strip}}, 4,4, 1, []{{asType .}}{0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0}}, - { {{asType . | title | strip}}, 4,4, 2, []{{asType .}}{0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0}}, - { {{asType . | title | strip}}, 4,4, 3, []{{asType .}}{0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, - { {{asType . | title | strip}}, 4,4, 4, []{{asType .}}{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, - { {{asType . | title | strip}}, 4,4, -1, []{{asType .}}{0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0}}, - { {{asType . | title | strip}}, 4,4, -2, []{{asType .}}{0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0}}, - { {{asType . | title | strip}}, 4,4, -3, []{{asType .}}{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0}}, - { {{asType . | title | strip}}, 4,4, -4, []{{asType .}}{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, - { {{asType . | title | strip}}, 4,5, 0, []{{asType .}}{1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0}}, - { {{asType . | title | strip}}, 4,5, 1, []{{asType .}}{0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1}}, - { {{asType . | title | strip}}, 4,5, -1, []{{asType .}}{0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0}}, - {{end -}} - {{end -}} -} - -func TestI(t *testing.T){ - assert := assert.New(t) - var T Tensor - - for i, it := range eyeTests { - T = I(it.E, it.R, it.C, it.K) - assert.True(Shape{it.R, it.C}.Eq(T.Shape())) - assert.Equal(it.correct, T.Data(), "Test %d-R: %d, C: %d K: %d", i, it.R, it.C, it.K) - } - -} -` - -var ( - onesTests *template.Template - eyeTests *template.Template -) - -func init() { - onesTests = template.Must(template.New("onesTest").Funcs(funcs).Parse(onesTestsRaw)) - eyeTests = template.Must(template.New("eyeTest").Funcs(funcs).Parse(eyeTestsRaw)) -} - -func generateDenseConsTests(f io.Writer, generic Kinds) { - onesTests.Execute(f, generic) - eyeTests.Execute(f, generic) -} +package main + +import ( + "io" + "text/template" +) + +const onesTestsRaw = `var onesTests = []struct { + of dtype.Dtype + shape Shape + correct interface{} +}{ + {{range .Kinds -}} + {{if isNumber . -}} + { {{asType . | title | strip}}, ScalarShape(), {{asType .}}(1)}, + { {{asType . | title | strip}}, Shape{2,2}, []{{asType .}}{1,1,1,1}}, + {{end -}} + {{end -}} + {Bool, ScalarShape(), true}, + {Bool, Shape{2,2}, []bool{true, true, true, true}}, +} + +func TestOnes(t *testing.T){ + assert := assert.New(t) + for _, ot := range onesTests{ + T := Ones(ot.of, ot.shape...) + assert.True(ot.shape.Eq(T.Shape())) + assert.Equal(ot.correct, T.Data()) + } +} +` + +const eyeTestsRaw = `// yes, it's a pun on eye tests, stop asking and go see your optometrist +var eyeTests = []struct{ + E dtype.Dtype + R, C, K int + + + correct interface{} +}{ + {{range .Kinds -}} + {{if isNumber . -}} + { {{asType . 
| title | strip}}, 4,4, 0, []{{asType .}}{1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1}}, + { {{asType . | title | strip}}, 4,4, 1, []{{asType .}}{0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0}}, + { {{asType . | title | strip}}, 4,4, 2, []{{asType .}}{0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0}}, + { {{asType . | title | strip}}, 4,4, 3, []{{asType .}}{0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, + { {{asType . | title | strip}}, 4,4, 4, []{{asType .}}{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, + { {{asType . | title | strip}}, 4,4, -1, []{{asType .}}{0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0}}, + { {{asType . | title | strip}}, 4,4, -2, []{{asType .}}{0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0}}, + { {{asType . | title | strip}}, 4,4, -3, []{{asType .}}{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0}}, + { {{asType . | title | strip}}, 4,4, -4, []{{asType .}}{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, + { {{asType . | title | strip}}, 4,5, 0, []{{asType .}}{1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0}}, + { {{asType . | title | strip}}, 4,5, 1, []{{asType .}}{0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1}}, + { {{asType . | title | strip}}, 4,5, -1, []{{asType .}}{0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0}}, + {{end -}} + {{end -}} +} + +func TestI(t *testing.T){ + assert := assert.New(t) + var T Tensor + + for i, it := range eyeTests { + T = I(it.E, it.R, it.C, it.K) + assert.True(Shape{it.R, it.C}.Eq(T.Shape())) + assert.Equal(it.correct, T.Data(), "Test %d-R: %d, C: %d K: %d", i, it.R, it.C, it.K) + } + +} +` + +var ( + onesTests *template.Template + eyeTests *template.Template +) + +func init() { + onesTests = template.Must(template.New("onesTest").Funcs(funcs).Parse(onesTestsRaw)) + eyeTests = template.Must(template.New("eyeTest").Funcs(funcs).Parse(eyeTestsRaw)) +} + +func generateDenseConsTests(f io.Writer, generic Kinds) { + onesTests.Execute(f, generic) + eyeTests.Execute(f, generic) +} diff --git a/genlib2/dense_getset.go b/genlib2/dense_getset.go index 4d2b415..2f3a38b 100644 --- a/genlib2/dense_getset.go +++ b/genlib2/dense_getset.go @@ -47,10 +47,10 @@ const copyIterRaw = `func copyDenseIter(dest, src *Dense, diter, siter *FlatIter } if diter == nil { - diter = NewFlatIterator(dest.AP) + diter = newFlatIterator(&dest.AP) } if siter == nil { - siter = NewFlatIterator(src.AP) + siter = newFlatIterator(&src.AP) } isMasked:= src.IsMasked() diff --git a/genlib2/dense_getset_tests.go b/genlib2/dense_getset_tests.go index 15cc820..50bafb3 100644 --- a/genlib2/dense_getset_tests.go +++ b/genlib2/dense_getset_tests.go @@ -102,8 +102,8 @@ func makeZeroTests(generic Kinds) []testData { } const getTestRaw = `var denseSetGetTests = []struct { - of Dtype - data interface{} + of dtype.Dtype + data interface{} set interface{} correct []interface{} @@ -129,7 +129,7 @@ func TestDense_setget(t *testing.T) { ` const memsetTestRaw = `var denseMemsetTests = []struct{ - of Dtype + of dtype.Dtype data interface{} val interface{} shape Shape @@ -139,7 +139,7 @@ const memsetTestRaw = `var denseMemsetTests = []struct{ {{range . 
-}} {{$val := .Set -}} {{$k := .Kind -}} - { {{title .Kind.String | strip}}, []{{asType .Kind}}{ {{range .TestData0 -}}{{printf "%v" .}}, {{end -}} }, {{asType .Kind}}({{$val}}), Shape{2,3}, []{{asType .Kind}}{ {{range .Correct}} {{printf "%v" .}}, {{end -}} } }, + { {{title .Kind.String | strip}}, []{{asType .Kind}}{ {{range .TestData0 -}}{{printf "%v" .}}, {{end -}} }, {{asType .Kind}}({{$val}}), Shape{2,3}, []{{asType .Kind}}{ {{range .Correct}} {{printf "%v" .}}, {{end -}} } }, {{end -}} } @@ -159,7 +159,7 @@ func TestDense_memset(t *testing.T){ ` const zeroTestRaw = `var denseZeroTests = []struct{ - of Dtype + of dtype.Dtype data interface{} correct interface{} @@ -167,18 +167,18 @@ const zeroTestRaw = `var denseZeroTests = []struct{ {{range . -}} {{$val := .Set -}} {{$k := .Kind -}} - { {{title .Kind.String | strip}}, []{{asType .Kind}}{ {{range .TestData0 -}}{{printf "%v" .}}, {{end -}} }, []{{asType .Kind}}{ {{range .Correct}} {{printf "%v" .}}, {{end -}} } }, + { {{title .Kind.String | strip}}, []{{asType .Kind}}{ {{range .TestData0 -}}{{printf "%v" .}}, {{end -}} }, []{{asType .Kind}}{ {{range .Correct}} {{printf "%v" .}}, {{end -}} } }, {{end -}} } func TestDense_Zero(t *testing.T) { assert := assert.New(t) for _, mts := range denseZeroTests { - + typ := reflect.TypeOf(mts.data) val := reflect.ValueOf(mts.data) data := reflect.MakeSlice(typ, val.Len(), val.Cap()) - reflect.Copy(data, val) + reflect.Copy(data, val) T := New(Of(mts.of), WithBacking(data.Interface())) T.Zero() @@ -188,7 +188,7 @@ func TestDense_Zero(t *testing.T) { T2, _ := T.Slice(nil) T2.Zero() assert.Equal(mts.correct, T2.Data()) - } + } } ` diff --git a/genlib2/dense_io.go b/genlib2/dense_io.go index e6e98fe..9b971ac 100644 --- a/genlib2/dense_io.go +++ b/genlib2/dense_io.go @@ -9,21 +9,45 @@ import ( const writeNpyRaw = ` type binaryWriter struct { io.Writer - error + err error seq int } -func (w binaryWriter) w(x interface{}) { - if w.error != nil { +func (w *binaryWriter) w(x interface{}) { + if w.err != nil { return } - binary.Write(w, binary.LittleEndian, x) + w.err = binary.Write(w, binary.LittleEndian, x) w.seq++ } -func (w binaryWriter) Error() string { - return fmt.Sprintf("Error at sequence %d : %v", w.seq, w.error.Error()) +func (w *binaryWriter) Err() error { + if w.err == nil { + return nil + } + return errors.Wrapf(w.err, "Sequence %d", w.seq) +} + +type binaryReader struct { + io.Reader + err error + seq int +} + +func (r *binaryReader) Read(data interface{}) { + if r.err != nil { + return + } + r.err = binary.Read(r.Reader, binary.LittleEndian, data) + r.seq++ +} + +func (r *binaryReader) Err() error { + if r.err == nil { + return nil + } + return errors.Wrapf(r.err, "Sequence %d", r.seq) } // WriteNpy writes the *Tensor as a numpy compatible serialized file. @@ -32,19 +56,26 @@ func (w binaryWriter) Error() string { // http://docs.scipy.org/doc/numpy/neps/npy-format.html // // Gorgonia specifically uses Version 1.0, as 65535 bytes should be more than enough for the headers. -// The values are written in little endian order, because let's face it - +// The values are written in little endian order, because let's face it - // 90% of the world's computers are running on x86+ processors. // // This method does not close the writer. Closing (if needed) is deferred to the caller // If tensor is masked, invalid values are replaced by the default fill value. 
func (t *Dense) WriteNpy(w io.Writer) (err error) { var npdt string - if npdt, err = t.t.numpyDtype(); err != nil{ - return + if npdt, err = t.t.NumpyDtype(); err != nil{ + return } - header := "{'descr': '<%v', 'fortran_order': False, 'shape': %v}" - header = fmt.Sprintf(header, npdt, t.Shape()) + var header string + if t.Dims() == 1 { + // when t is a 1D vector, numpy expects "(N,)" instead of "(N)" which t.Shape() returns. + header = "{'descr': '<%v', 'fortran_order': False, 'shape': (%d,)}" + header = fmt.Sprintf(header, npdt, t.Shape()[0]) + } else { + header = "{'descr': '<%v', 'fortran_order': False, 'shape': %v}" + header = fmt.Sprintf(header, npdt, t.Shape()) + } padding := 16 - ((10 + len(header)) % 16) if padding > 0 { header = header + strings.Repeat(" ", padding) @@ -54,15 +85,15 @@ func (t *Dense) WriteNpy(w io.Writer) (err error) { bw.w(byte(1)) // major version bw.w(byte(0)) // minor version bw.w(uint16(len(header))) // 4 bytes to denote header length - if bw.error != nil { - return bw + if err = bw.Err() ; err != nil { + return err } bw.Write([]byte(header)) bw.seq = 0 if t.IsMasked(){ fillval:=t.FillValue() - it := FlatMaskedIteratorFromDense(t) + it := FlatMaskedIteratorFromDense(t) for i, err := it.Next(); err == nil; i, err = it.Next() { if t.mask[i] { bw.w(fillval) @@ -76,10 +107,7 @@ func (t *Dense) WriteNpy(w io.Writer) (err error) { } } - if bw.error != nil { - return bw - } - return nil + return bw.Err() } ` @@ -98,7 +126,7 @@ func (t *Dense) WriteCSV(w io.Writer, formats ...string) (err error) { } cw := csv.NewWriter(w) - it := IteratorFromDense(t) + it := IteratorFromDense(t) coord := it.Coord() // rows := t.Shape()[0] @@ -174,7 +202,7 @@ func (t *Dense) GobEncode() (p []byte, err error){ if err = encoder.Encode(&data); err != nil { return } - + return buf.Bytes(), err } ` @@ -184,7 +212,7 @@ func (t *Dense) GobDecode(p []byte) (err error){ buf := bytes.NewBuffer(p) decoder := gob.NewDecoder(buf) - + var shape Shape if err = decoder.Decode(&shape); err != nil { return @@ -203,7 +231,7 @@ func (t *Dense) GobDecode(p []byte) (err error){ } } - t.AP = NewAP(shape, strides) + t.AP.Init(shape, strides) t.AP.o = o t.AP.Δ = tr @@ -211,89 +239,74 @@ func (t *Dense) GobDecode(p []byte) (err error){ if err = decoder.Decode(&mask); err != nil { return } - + var data interface{} if err = decoder.Decode(&data); err != nil { return } + t.fromSlice(data) t.addMask(mask) t.fix() + if t.e == nil { + t.e = StdEng{} + } return t.sanity() } ` +const npyDescRE = `var npyDescRE = regexp.MustCompile(` + "`" + `'descr':` + `\` + `s*'([^']*)'` + "`" + ")" +const rowOrderRE = `var rowOrderRE = regexp.MustCompile(` + "`" + `'fortran_order':\s*(False|True)` + "`)" +const shapeRE = `var shapeRE = regexp.MustCompile(` + "`" + `'shape':\s*\(([^\(]*)\)` + "`)" const readNpyRaw = `// ReadNpy reads NumPy formatted files into a *Dense func (t *Dense) ReadNpy(r io.Reader) (err error){ + br := binaryReader{Reader: r} var magic [6]byte - if _, err = r.Read(magic[:]); err != nil { - return - } - if string(magic[:]) != "\x93NUMPY" { - err = errors.Errorf("Not a numpy file. Got %q as the magic number instead", string(magic[:])) - return + if br.Read(magic[:]); string(magic[:]) != "\x93NUMPY" { + return errors.Errorf("Not a numpy file. 
Got %q as the magic number instead", string(magic[:])) } - var version byte - if err = binary.Read(r, binary.LittleEndian, &version); err != nil { - return - } - if version != 1 { - err = errors.New("Only verion 1.0 of numpy's serialization format is currently supported (65535 bytes ought to be enough for a header)") - return + var version, minor byte + if br.Read(&version); version != 1 { + return errors.New("Only verion 1.0 of numpy's serialization format is currently supported (65535 bytes ought to be enough for a header)") } - var minor byte - if err = binary.Read(r, binary.LittleEndian, &minor); err != nil { - return - } - if minor != 0 { - err = errors.New("Only verion 1.0 of numpy's serialization format is currently supported (65535 bytes ought to be enough for a header)") - return + if br.Read(&minor); minor != 0 { + return errors.New("Only verion 1.0 of numpy's serialization format is currently supported (65535 bytes ought to be enough for a header)") } var headerLen uint16 - if err = binary.Read(r, binary.LittleEndian, &headerLen); err != nil { - return - } - + br.Read(&headerLen) header := make([]byte, int(headerLen)) - if _, err = r.Read(header); err != nil { + br.Read(header) + if err = br.Err(); err != nil { return } - desc := regexp.MustCompile(` + "`'descr':" + `\s` + "*'([^']*)'`" + `) - match := desc.FindSubmatch(header) - if match == nil { - err = errors.New("No dtype information in npy file") - return + // extract stuff from header + var match [][]byte + if match = npyDescRE.FindSubmatch(header); match == nil { + return errors.New("No dtype information in npy file") } // TODO: check for endianness. For now we assume everything is little endian - var dt Dtype - if dt, err = fromNumpyDtype(string(match[1][1:])); err != nil { + if t.t, err = dtype.FromNumpyDtype(string(match[1][1:])); err != nil { return } - t.t = dt - rowOrder := regexp.MustCompile(` + "`'fortran_order':" + `\s` + "*(False|True)`" + `) - match = rowOrder.FindSubmatch(header) - if match == nil { - err = errors.New("No Row Order information found in the numpy file") - return + if match = rowOrderRE.FindSubmatch(header); match == nil { + return errors.New("No Row Order information found in the numpy file") } if string(match[1]) != "False" { - err = errors.New("Cannot yet read from Fortran Ordered Numpy files") - return + return errors.New("Cannot yet read from Fortran Ordered Numpy files") } - shpRe := regexp.MustCompile(` + "`'shape':" + `\s*\(([^\(]*)\)` + "`" + `) - match = shpRe.FindSubmatch(header) - if match == nil { - err = errors.New("No shape information found in npy file") - return + if match = shapeRE.FindSubmatch(header); match == nil { + return errors.New("No shape information found in npy file") } sizesStr := strings.Split(string(match[1]), ",") + + var shape Shape for _, s := range sizesStr { s = strings.Trim(s, " ") @@ -311,7 +324,6 @@ func (t *Dense) ReadNpy(r io.Reader) (err error){ if t.e == nil { t.e = StdEng{} } - t.makeArray(size) switch t.t.Kind() { @@ -319,21 +331,24 @@ func (t *Dense) ReadNpy(r io.Reader) (err error){ case reflect.{{reflectKind .}}: data := t.{{sliceOf .}} for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil{ - return - } + br.Read(&data[i]) } {{end -}} } - t.AP = BorrowAP(len(shape)) + if err = br.Err(); err != nil { + return err + } + + t.AP.zeroWithDims(len(shape)) t.setShape(shape...) 
t.fix() return t.sanity() } ` -const readCSVRaw = `// convFromStrs conversts a []string to a slice of the Dtype provided -func convFromStrs(to Dtype, record []string) (interface{}, error) { +const readCSVRaw = `// convFromStrs converts a []string to a slice of the Dtype provided. It takes a provided backing slice. +// If into is nil, then a backing slice will be created. +func convFromStrs(to dtype.Dtype, record []string, into interface{}) (interface{}, error) { var err error switch to.Kind() { {{range .Kinds -}} @@ -341,6 +356,13 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { {{if isOrd . -}} case reflect.{{reflectKind .}}: retVal := make([]{{asType .}}, len(record)) + var backing []{{asType .}} + if into == nil { + backing = make([]{{asType .}}, 0, len(record)) + }else{ + backing = into.([]{{asType .}}) + } + for i, v := range record { {{if eq .String "float64" -}} if retVal[i], err = strconv.ParseFloat(v, 64); err != nil { @@ -366,10 +388,20 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { retVal[i] = {{asType .}}(u) {{end -}} } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil {{end -}} {{end -}} {{end -}} + case reflect.String: + var backing []string + if into == nil { + backing = make([]string, 0, len(record)) + }else{ + backing = into.([]string) + } + backing = append(backing, record...) + return backing, nil default: return nil,errors.Errorf(methodNYI, "convFromStrs", to) } @@ -388,62 +420,221 @@ func (t *Dense) ReadCSV(r io.Reader, opts ...FuncOpt) (err error) { cr := csv.NewReader(r) var record []string - var row interface{} var rows, cols int + var backing interface{} + for { + record, err = cr.Read() + if err == io.EOF{ + break + } else if err != nil { + return + } + if backing, err = convFromStrs(as, record, backing); err != nil { + return + } + cols = len(record) + rows++ + } + t.fromSlice(backing) + t.AP.zero() + t.AP.SetShape(rows, cols) + return nil + return errors.Errorf("not yet handled") +} +` - switch as.Kind() { - {{range .Kinds -}} - {{if isNumber . -}} - {{if isOrd . -}} - case reflect.{{reflectKind .}}: - var backing []{{asType .}} - for { - record, err = cr.Read() - if err == io.EOF{ - break - } +var fbEncodeDecodeRaw = `// FBEncode encodes to a byte slice using flatbuffers. 
+// +// Only natively accessible data can be encided +func (t *Dense) FBEncode() ([]byte, error) { + builder := flatbuffers.NewBuilder(1024) + + fb.DenseStartShapeVector(builder, len(t.shape)) + for i := len(t.shape) - 1; i >= 0; i-- { + builder.PrependInt32(int32(t.shape[i])) + } + shape := builder.EndVector(len(t.shape)) + + fb.DenseStartStridesVector(builder, len(t.strides)) + for i := len(t.strides) - 1; i >= 0; i-- { + builder.PrependInt32(int32(t.strides[i])) + } + strides := builder.EndVector(len(t.strides)) + + var o uint32 + switch { + case t.o.IsRowMajor() && t.o.IsContiguous(): + o = 0 + case t.o.IsRowMajor() && !t.o.IsContiguous(): + o = 1 + case t.o.IsColMajor() && t.o.IsContiguous(): + o = 2 + case t.o.IsColMajor() && !t.o.IsContiguous(): + o = 3 + } + + var triangle int32 + switch t.Δ { + case NotTriangle: + triangle = fb.TriangleNOT_TRIANGLE + case Upper: + triangle = fb.TriangleUPPER + case Lower: + triangle = fb.TriangleLOWER + case Symmetric: + triangle = fb.TriangleSYMMETRIC + } + + dt := builder.CreateString(t.Dtype().String()) + data := t.byteSlice() + + fb.DenseStartDataVector(builder, len(data)) + for i := len(data) - 1; i >= 0; i-- { + builder.PrependUint8(data[i]) + } + databyte := builder.EndVector(len(data)) + + fb.DenseStart(builder) + fb.DenseAddShape(builder, shape) + fb.DenseAddStrides(builder, strides) + fb.DenseAddO(builder, o) + fb.DenseAddT(builder, triangle) + fb.DenseAddType(builder, dt) + fb.DenseAddData(builder, databyte) + serialized := fb.DenseEnd(builder) + builder.Finish(serialized) + + return builder.FinishedBytes(), nil +} - if err != nil { - return - } +// FBDecode decodes a byteslice from a flatbuffer table into a *Dense +func (t *Dense) FBDecode(buf []byte) error { + serialized := fb.GetRootAsDense(buf, 0) - if row, err = convFromStrs({{asType . | strip | title}}, record); err != nil { - return - } - backing = append(backing, row.([]{{asType .}})...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - {{end -}} - {{end -}} - {{end -}} - case reflect.String: - var backing []string - for { - record, err = cr.Read() - if err == io.EOF{ - break - } + o := serialized.O() + switch o { + case 0: + t.o = 0 + case 1: + t.o = MakeDataOrder(NonContiguous) + case 2: + t.o = MakeDataOrder(ColMajor) + case 3: + t.o = MakeDataOrder(ColMajor, NonContiguous) + } - if err != nil { - return - } - backing = append(backing, record...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - default: - return errors.Errorf("%v not yet handled", as) + tri := serialized.T() + switch tri { + case fb.TriangleNOT_TRIANGLE: + t.Δ = NotTriangle + case fb.TriangleUPPER: + t.Δ = Upper + case fb.TriangleLOWER: + t.Δ = Lower + case fb.TriangleSYMMETRIC: + t.Δ = Symmetric } - return errors.Errorf("not yet handled") + + t.shape = Shape(BorrowInts(serialized.ShapeLength())) + for i := 0; i < serialized.ShapeLength(); i++ { + t.shape[i] = int(int32(serialized.Shape(i))) + } + + t.strides = BorrowInts(serialized.StridesLength()) + for i := 0; i < serialized.ShapeLength(); i++ { + t.strides[i] = int(serialized.Strides(i)) + } + typ := string(serialized.Type()) + dt, err := dtype.FindByName(typ) + if err != nil { + return errors.Wrap(err, "Failed to decode FlatBuffers") + } + t.t = dt + + if t.e == nil { + t.e = StdEng{} + } + t.makeArray(t.shape.TotalSize()) + + // allocated data. 
Now time to actually copy over the data + db := t.byteSlice() + copy(db, serialized.DataBytes()) + t.fix() + return t.sanity() +} +` + +var pbEncodeDecodeRaw = `// PBEncode encodes the Dense into a protobuf byte slice. +func (t *Dense) PBEncode() ([]byte, error) { + var toSerialize pb.Dense + toSerialize.Shape = make([]int32, len(t.shape)) + for i, v := range t.shape { + toSerialize.Shape[i] = int32(v) + } + toSerialize.Strides = make([]int32, len(t.strides)) + for i, v := range t.strides { + toSerialize.Strides[i] = int32(v) + } + + switch { + case t.o.IsRowMajor() && t.o.IsContiguous(): + toSerialize.O = pb.RowMajorContiguous + case t.o.IsRowMajor() && !t.o.IsContiguous(): + toSerialize.O = pb.RowMajorNonContiguous + case t.o.IsColMajor() && t.o.IsContiguous(): + toSerialize.O = pb.ColMajorContiguous + case t.o.IsColMajor() && !t.o.IsContiguous(): + toSerialize.O = pb.ColMajorNonContiguous + } + toSerialize.T = pb.Triangle(t.Δ) + toSerialize.Type = t.t.String() + data := t.byteSlice() + toSerialize.Data = make([]byte, len(data)) + copy(toSerialize.Data, data) + return toSerialize.Marshal() +} + +// PBDecode unmarshalls a protobuf byteslice into a *Dense. +func (t *Dense) PBDecode(buf []byte) error { + var toSerialize pb.Dense + if err := toSerialize.Unmarshal(buf); err != nil { + return err + } + t.shape = make(Shape, len(toSerialize.Shape)) + for i, v := range toSerialize.Shape { + t.shape[i] = int(v) + } + t.strides = make([]int, len(toSerialize.Strides)) + for i, v := range toSerialize.Strides { + t.strides[i] = int(v) + } + + switch toSerialize.O { + case pb.RowMajorContiguous: + case pb.RowMajorNonContiguous: + t.o = MakeDataOrder(NonContiguous) + case pb.ColMajorContiguous: + t.o = MakeDataOrder(ColMajor) + case pb.ColMajorNonContiguous: + t.o = MakeDataOrder(ColMajor, NonContiguous) + } + t.Δ = Triangle(toSerialize.T) + typ := string(toSerialize.Type) + dt, err := dtype.FindByName(typ) + if err != nil { + return errors.Wrap(err, "Failed to decode ProtoBuf") + } + t.t = dt + + if t.e == nil { + t.e = StdEng{} + } + t.makeArray(t.shape.TotalSize()) + + // allocated data. 
Now time to actually copy over the data + db := t.byteSlice() + copy(db, toSerialize.Data) + return t.sanity() } ` @@ -464,15 +655,30 @@ func init() { func generateDenseIO(f io.Writer, generic Kinds) { mk := Kinds{Kinds: filter(generic.Kinds, isNumber)} - // writes - fmt.Fprintln(f, writeNpyRaw) - fmt.Fprint(f, "\n") - fmt.Fprintln(f, writeCSVRaw) - fmt.Fprint(f, "\n") + fmt.Fprint(f, "/* GOB SERIALIZATION */\n\n") gobEncode.Execute(f, mk) + gobDecode.Execute(f, mk) + fmt.Fprint(f, "\n") - // reads + fmt.Fprint(f, "/* NPY SERIALIZATION */\n\n") + fmt.Fprintln(f, npyDescRE) + fmt.Fprintln(f, rowOrderRE) + fmt.Fprintln(f, shapeRE) + f.Write([]byte(writeNpyRaw)) readNpy.Execute(f, mk) - gobDecode.Execute(f, mk) + fmt.Fprint(f, "\n") + + fmt.Fprint(f, "/* CSV SERIALIZATION */\n\n") + f.Write([]byte(writeCSVRaw)) readCSV.Execute(f, mk) + fmt.Fprint(f, "\n") + + fmt.Fprint(f, "/* FB SERIALIZATION */\n\n") + fmt.Fprintln(f, fbEncodeDecodeRaw) + fmt.Fprint(f, "\n") + + fmt.Fprint(f, "/* PB SERIALIZATION */\n\n") + fmt.Fprintln(f, pbEncodeDecodeRaw) + fmt.Fprint(f, "\n") + } diff --git a/genlib2/dense_maskedmethods.go b/genlib2/dense_maskedmethods.go index ce1133c..644e37a 100644 --- a/genlib2/dense_maskedmethods.go +++ b/genlib2/dense_maskedmethods.go @@ -1,103 +1,103 @@ -package main - -import ( - "fmt" - "io" - "reflect" - "text/template" -) - -var maskcmpMethods = []struct { - Name string - Desc string - NumArgs int - CmpFn string - ReqFloat bool - Kinds []reflect.Kind -}{ - {"MaskedEqual", "equal to ", 1, "a == x", false, nil}, - {"MaskedNotEqual", "not equal to ", 1, "a != x", false, nil}, - {"MaskedValues", " equal to ", 3, "math.Abs(float64(a-x)) <= delta", true, nil}, - {"MaskedGreater", " greater than ", 1, "a > x", false, nil}, - {"MaskedGreaterEqual", " greater than or equal to ", 1, "a >= x", false, nil}, - {"MaskedLess", " less than ", 1, "a < x", false, nil}, - {"MaskedLessEqual", " less than or equal to ", 1, "a <= x", false, nil}, - {"MaskedInside", " inside range of ", 2, "(a >= x) && (a <= y)", false, nil}, - {"MaskedOutside", " outside range of ", 2, "(a < x) || (a > y)", false, nil}, -} - -const maskCmpMethodRaw = `// {{.Name}} sets the mask to true where the corresponding data is {{.Desc}} val -// Any values must be the same type as the tensor -func (t *Dense) {{.Name}}({{if ge .NumArgs 1 -}} val1 interface{} {{end}} {{if ge .NumArgs 2 -}} , val2 interface{} {{end}} {{if ge .NumArgs 3 -}} , val3 ...interface{}{{end}})(err error){ - {{if .ReqFloat}} - if !isFloat(t.t) { - err = errors.Errorf("Can only do {{.Name}} with floating point types") - return - } - {{end}} - - if !t.IsMasked() { - t.makeMask() - } - - {{$numargs := .NumArgs}} - {{$name := .Name}} - {{$fn := .CmpFn}} - {{$reqFloat := .ReqFloat}} - switch t.t.Kind(){ - {{range .Kinds -}} - {{if isParameterized . 
-}} - {{else -}} - {{if or (not (isOrd .)) (and $reqFloat (isntFloat .)) -}} - {{else -}} - case reflect.{{reflectKind .}}: - data := t.{{sliceOf .}} - mask := t.mask - {{if ge $numargs 1 -}} x := val1.({{asType .}}) {{end}} - {{if ge $numargs 2 -}} y := val2.({{asType .}}){{end}} - {{if ge $numargs 3 -}} - {{if eq $name "MaskedValues"}} - delta := float64(1.0e-8) - if len(val3) > 0 { - delta = float64(val3[0].({{asType .}})) + float64(y)*math.Abs(float64(x)) - } - {{else}} - z := val3.({{asType .}}) - {{end}} - {{end}} - if t.maskIsSoft{ - for i := range data { - a := data[i] - mask[i] = ({{$fn}}) - } - } else { - for i := range data { - a := data[i] - mask[i] = mask[i] || ({{$fn}}) - } - } - - {{end}} - {{end}} - {{end}} -} -return nil -} -` - -var ( - maskCmpMethod *template.Template -) - -func init() { - maskCmpMethod = template.Must(template.New("maskcmpmethod").Funcs(funcs).Parse(maskCmpMethodRaw)) -} - -func generateDenseMaskedMethods(f io.Writer, generic Kinds) { - for _, mm := range maskcmpMethods { - mm.Kinds = generic.Kinds - fmt.Fprintf(f, "/* %s */ \n\n", mm.Name) - maskCmpMethod.Execute(f, mm) - - } -} +package main + +import ( + "fmt" + "io" + "reflect" + "text/template" +) + +var maskcmpMethods = []struct { + Name string + Desc string + NumArgs int + CmpFn string + ReqFloat bool + Kinds []reflect.Kind +}{ + {"MaskedEqual", "equal to ", 1, "a == x", false, nil}, + {"MaskedNotEqual", "not equal to ", 1, "a != x", false, nil}, + {"MaskedValues", " equal to ", 3, "math.Abs(float64(a-x)) <= delta", true, nil}, + {"MaskedGreater", " greater than ", 1, "a > x", false, nil}, + {"MaskedGreaterEqual", " greater than or equal to ", 1, "a >= x", false, nil}, + {"MaskedLess", " less than ", 1, "a < x", false, nil}, + {"MaskedLessEqual", " less than or equal to ", 1, "a <= x", false, nil}, + {"MaskedInside", " inside range of ", 2, "(a >= x) && (a <= y)", false, nil}, + {"MaskedOutside", " outside range of ", 2, "(a < x) || (a > y)", false, nil}, +} + +const maskCmpMethodRaw = `// {{.Name}} sets the mask to true where the corresponding data is {{.Desc}} val +// Any values must be the same type as the tensor +func (t *Dense) {{.Name}}({{if ge .NumArgs 1 -}} val1 interface{} {{end}} {{if ge .NumArgs 2 -}} , val2 interface{} {{end}} {{if ge .NumArgs 3 -}} , val3 ...interface{}{{end}})(err error){ + {{if .ReqFloat}} + if !isFloat(t.t) { + err = errors.Errorf("Can only do {{.Name}} with floating point types") + return + } + {{end}} + + if !t.IsMasked() { + t.makeMask() + } + + {{$numargs := .NumArgs}} + {{$name := .Name}} + {{$fn := .CmpFn}} + {{$reqFloat := .ReqFloat}} + switch t.t.Kind(){ + {{range .Kinds -}} + {{if isParameterized . 
-}} + {{else -}} + {{if or (not (isOrd .)) (and $reqFloat (isntFloat .)) -}} + {{else -}} + case reflect.{{reflectKind .}}: + data := t.{{sliceOf .}} + mask := t.mask + {{if ge $numargs 1 -}} x := val1.({{asType .}}) {{end}} + {{if ge $numargs 2 -}} y := val2.({{asType .}}){{end}} + {{if ge $numargs 3 -}} + {{if eq $name "MaskedValues"}} + delta := float64(1.0e-8) + if len(val3) > 0 { + delta = float64(val3[0].({{asType .}})) + float64(y)*math.Abs(float64(x)) + } + {{else}} + z := val3.({{asType .}}) + {{end}} + {{end}} + if t.maskIsSoft{ + for i := range data { + a := data[i] + mask[i] = ({{$fn}}) + } + } else { + for i := range data { + a := data[i] + mask[i] = mask[i] || ({{$fn}}) + } + } + + {{end}} + {{end}} + {{end}} +} +return nil +} +` + +var ( + maskCmpMethod *template.Template +) + +func init() { + maskCmpMethod = template.Must(template.New("maskcmpmethod").Funcs(funcs).Parse(maskCmpMethodRaw)) +} + +func generateDenseMaskedMethods(f io.Writer, generic Kinds) { + for _, mm := range maskcmpMethods { + mm.Kinds = generic.Kinds + fmt.Fprintf(f, "/* %s */ \n\n", mm.Name) + maskCmpMethod.Execute(f, mm) + + } +} diff --git a/genlib2/dense_maskedmethods_tests.go b/genlib2/dense_maskedmethods_tests.go index addefbd..9b53e1a 100644 --- a/genlib2/dense_maskedmethods_tests.go +++ b/genlib2/dense_maskedmethods_tests.go @@ -18,29 +18,57 @@ const testMaskCmpMethodRaw = `func TestDense_{{title .Name}}_{{short .Kind}}(t * assert.False(T.IsMasked()) data := T.{{sliceOf .Kind}} for i := range data { +{{if eq "string" (asType .Kind) -}} + data[i] = fmt.Sprint(i) +{{else -}} data[i] = {{asType .Kind}}(i) +{{end -}} } +{{if eq "string" (asType .Kind) -}} + T.MaskedEqual(fmt.Sprint(0)) +{{else -}} T.MaskedEqual({{asType .Kind}}(0)) +{{end -}} assert.True(T.IsMasked()) +{{if eq "string" (asType .Kind) -}} + T.MaskedEqual(fmt.Sprint(1)) +{{else -}} T.MaskedEqual({{asType .Kind}}(1)) +{{end -}} assert.True(T.mask[0] && T.mask[1]) +{{if eq "string" (asType .Kind) -}} + T.MaskedNotEqual(fmt.Sprint(2)) +{{else -}} T.MaskedNotEqual({{asType .Kind}}(2)) +{{end -}} assert.False(T.mask[2] && !(T.mask[0])) T.ResetMask() +{{if eq "string" (asType .Kind) -}} + T.MaskedInside(fmt.Sprint(1), fmt.Sprint(22)) +{{else -}} T.MaskedInside({{asType .Kind}}(1), {{asType .Kind}}(22)) +{{end -}} assert.True(!T.mask[0] && !T.mask[23] && T.mask[1] && T.mask[22]) T.ResetMask() +{{if eq "string" (asType .Kind) -}} + T.MaskedOutside(fmt.Sprint(1), fmt.Sprint(22)) +{{else -}} T.MaskedOutside({{asType .Kind}}(1), {{asType .Kind}}(22)) +{{end -}} assert.True(T.mask[0] && T.mask[23] && !T.mask[1] && !T.mask[22]) T.ResetMask() for i := 0; i < 5; i++ { +{{if eq "string" (asType .Kind) -}} + T.MaskedEqual(fmt.Sprint(i*10)) +{{else -}} T.MaskedEqual({{asType .Kind}}(i*10)) +{{end -}} } it := IteratorFromDense(T) - + j := 0 for _, err := it.Next(); err == nil; _, err = it.Next() { j++ diff --git a/genlib2/dense_reduction_methods_tests.go b/genlib2/dense_reduction_methods_tests.go index 62d042f..342b1d4 100644 --- a/genlib2/dense_reduction_methods_tests.go +++ b/genlib2/dense_reduction_methods_tests.go @@ -1,155 +1,164 @@ -package main - -import ( - "fmt" - "io" - "text/template" -) - -const testDenseSumRaw = `var sumTests = []struct { - name string - of Dtype - shape Shape - along []int - - correctShape Shape - correct interface{} -}{ - {{range .Kinds -}} - {{if isNumber . -}} - {"common case: T.Sum() for {{.}}", {{asType . | title}}, Shape{2,3}, []int{}, ScalarShape(), {{asType .}}(15)}, - {"A.Sum(0) for {{.}}", {{asType . 
| title}}, Shape{2,3}, []int{0}, Shape{3}, []{{asType .}}{3, 5, 7}}, - {"A.Sum(1) for {{.}}", {{asType . | title}}, Shape{2,3},[]int{1}, Shape{2}, []{{asType .}}{3, 12}}, - {"A.Sum(0,1) for {{.}}", {{asType . | title}}, Shape{2,3},[]int{0, 1}, ScalarShape(), {{asType .}}(15)}, - {"A.Sum(1,0) for {{.}}", {{asType . | title}}, Shape{2,3},[]int{1, 0}, ScalarShape(), {{asType .}}(15)}, - {"3T.Sum(1,2) for {{.}}", {{asType . | title}}, Shape{2,3,4}, []int{1,2}, Shape{2}, []{{asType .}}{66, {{if eq .String "int8"}}-46{{else}}210{{end}} }}, - {{end -}} - {{end -}} -} -func TestDense_Sum(t *testing.T){ - assert := assert.New(t) - var T, T2 *Dense - var err error - - for _, sts := range sumTests { - T = New(WithShape(sts.shape...), WithBacking(Range(sts.of, 0, sts.shape.TotalSize()))) - if T2, err = T.Sum(sts.along ...); err != nil { - t.Error(err) - continue - } - assert.True(sts.correctShape.Eq(T2.Shape())) - assert.Equal(sts.correct, T2.Data()) - } - - // idiots - _,err =T.Sum(1000) - assert.NotNil(err) -} -` - -const testDenseMaxRaw = `var maxTests = []struct { - name string - of Dtype - shape Shape - along []int - - correctShape Shape - correct interface{} -}{ - {{range .Kinds -}} - {{if isNumber . -}} - {{if isOrd . -}} - {"common case: T.Max() for {{.}}", {{asType . | title}}, Shape{2,3}, []int{}, ScalarShape(), {{asType .}}(5)}, - {"A.Max(0)", {{asType . | title}}, Shape{2,3},[]int{0}, Shape{3}, []{{asType . }}{3, 4, 5}}, - {"A.Max(1)", {{asType . | title}}, Shape{2,3},[]int{1}, Shape{2}, []{{asType . }}{2,5}}, - {"A.Max(0,1)", {{asType . | title}}, Shape{2,3},[]int{0, 1}, ScalarShape(), {{asType .}}(5)}, - {"A.Max(1,0)", {{asType . | title}}, Shape{2,3},[]int{1, 0}, ScalarShape(), {{asType .}}(5)}, - {"3T.Max(1,2)", {{asType . | title}}, Shape{2,3,4}, []int{1,2}, Shape{2}, []{{asType .}}{11, 23} }, - {{end -}} - {{end -}} - {{end -}} -} - -func TestDense_Max(t *testing.T){ - assert := assert.New(t) - var T, T2 *Dense - var err error - - for _, mts := range maxTests { - T = New(WithShape(mts.shape...), WithBacking(Range(mts.of, 0, mts.shape.TotalSize()))) - if T2, err = T.Max(mts.along...); err != nil{ - t.Error(err) - continue - } - assert.True(mts.correctShape.Eq(T2.Shape())) - assert.Equal(mts.correct, T2.Data()) - } - /* IDIOT TESTING TIME */ - _, err = T.Max(1000) - assert.NotNil(err) -} -` - -const testDenseMinRaw = `var minTests = []struct { - name string - of Dtype - shape Shape - along []int - - correctShape Shape - correct interface{} -}{ - {{range .Kinds -}} - {{if isNumber . -}} - {{if isOrd . -}} - {"common case: T.Min() for {{.}}", {{asType .|title}}, Shape{2,3}, []int{}, ScalarShape(), {{asType .}}(0)}, - {"A.Min(0)", {{asType .|title}}, Shape{2,3}, []int{0}, Shape{3}, []{{asType .}}{0, 1, 2}}, - {"A.Min(1)", {{asType .|title}}, Shape{2,3}, []int{1}, Shape{2}, []{{asType .}}{0, 3}}, - {"A.Min(0,1)", {{asType .|title}}, Shape{2,3}, []int{0, 1}, ScalarShape(), {{asType .}}(0)}, - {"A.Min(1,0)", {{asType .|title}}, Shape{2,3}, []int{1, 0}, ScalarShape(), {{asType .}}(0)}, - {"3T.Min(1,2)", {{asType . 
| title}}, Shape{2,3,4}, []int{1,2}, Shape{2}, []{{asType .}}{0,12} }, - {{end -}} - {{end -}} - {{end -}} -} - -func TestDense_Min(t *testing.T){ - assert := assert.New(t) - var T, T2 *Dense - var err error - - for _, mts := range minTests { - T = New(WithShape(mts.shape...), WithBacking(Range(mts.of, 0, mts.shape.TotalSize()))) - if T2, err = T.Min(mts.along...); err != nil{ - t.Error(err) - continue - } - assert.True(mts.correctShape.Eq(T2.Shape())) - assert.Equal(mts.correct, T2.Data()) - } - - /* IDIOT TESTING TIME */ - _, err = T.Min(1000) - assert.NotNil(err) -} -` - -var ( - testDenseSum *template.Template - testDenseMax *template.Template - testDenseMin *template.Template -) - -func init() { - testDenseSum = template.Must(template.New("testDenseSum").Funcs(funcs).Parse(testDenseSumRaw)) - testDenseMax = template.Must(template.New("testDenseMax").Funcs(funcs).Parse(testDenseMaxRaw)) - testDenseMin = template.Must(template.New("testDenseMin").Funcs(funcs).Parse(testDenseMinRaw)) -} - -func generateDenseReductionMethodsTests(f io.Writer, generic Kinds) { - testDenseSum.Execute(f, generic) - fmt.Fprint(f, "\n") - testDenseMax.Execute(f, generic) - fmt.Fprint(f, "\n") - testDenseMin.Execute(f, generic) -} +package main + +import ( + "fmt" + "io" + "text/template" +) + +const testDenseSumRaw = `var sumTests = []struct { + name string + of dtype.Dtype + shape Shape + along []int + + correctShape Shape + correct interface{} +}{ + {{range .Kinds -}} + {{if isNumber . -}} + {"common case: T.Sum() for {{.}}", {{asType . | title}}, Shape{2,3}, []int{}, ScalarShape(), {{asType .}}(15)}, + {"A.Sum(0) for {{.}}", {{asType . | title}}, Shape{2,3}, []int{0}, Shape{3}, []{{asType .}}{3, 5, 7}}, + {"A.Sum(1) for {{.}}", {{asType . | title}}, Shape{2,3},[]int{1}, Shape{2}, []{{asType .}}{3, 12}}, + {"A.Sum(0,1) for {{.}}", {{asType . | title}}, Shape{2,3},[]int{0, 1}, ScalarShape(), {{asType .}}(15)}, + {"A.Sum(1,0) for {{.}}", {{asType . | title}}, Shape{2,3},[]int{1, 0}, ScalarShape(), {{asType .}}(15)}, + {"3T.Sum(1,2) for {{.}}", {{asType . | title}}, Shape{2,3,4}, []int{1,2}, Shape{2}, []{{asType .}}{66, {{if eq .String "int8"}}-46{{else}}210{{end}} }}, + {"4T.Sum() for {{.}}", {{asType . | title}}, Shape{2, 2, 2, 2},[]int{}, ScalarShape(), {{asType .}}(120)}, + {"4T.Sum(1,3) for {{.}}", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []{{asType .}}{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for {{.}}", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []{{asType .}}{44, 76}}, + {{end -}} + {{end -}} +} +func TestDense_Sum(t *testing.T){ + assert := assert.New(t) + var T, T2 *Dense + var err error + + for _, sts := range sumTests { + T = New(WithShape(sts.shape...), WithBacking(Range(sts.of, 0, sts.shape.TotalSize()))) + if T2, err = T.Sum(sts.along ...); err != nil { + t.Error(err) + continue + } + assert.True(sts.correctShape.Eq(T2.Shape())) + assert.Equal(sts.correct, T2.Data()) + } + + // idiots + _,err =T.Sum(1000) + assert.NotNil(err) +} +` + +const testDenseMaxRaw = `var maxTests = []struct { + name string + of dtype.Dtype + shape Shape + along []int + + correctShape Shape + correct interface{} +}{ + {{range .Kinds -}} + {{if isNumber . -}} + {{if isOrd . -}} + {"common case: T.Max() for {{.}}", {{asType . | title}}, Shape{2,3}, []int{}, ScalarShape(), {{asType .}}(5)}, + {"A.Max(0)", {{asType . | title}}, Shape{2,3},[]int{0}, Shape{3}, []{{asType . }}{3, 4, 5}}, + {"A.Max(1)", {{asType . | title}}, Shape{2,3},[]int{1}, Shape{2}, []{{asType . 
}}{2,5}}, + {"A.Max(0,1)", {{asType . | title}}, Shape{2,3},[]int{0, 1}, ScalarShape(), {{asType .}}(5)}, + {"A.Max(1,0)", {{asType . | title}}, Shape{2,3},[]int{1, 0}, ScalarShape(), {{asType .}}(5)}, + {"3T.Max(1,2)", {{asType . | title}}, Shape{2,3,4}, []int{1,2}, Shape{2}, []{{asType .}}{11, 23} }, + {"4T.Max()", {{asType . | title}}, Shape{2, 2, 2, 2},[]int{}, ScalarShape(), {{asType .}}(15)}, + {"4T.Max(1,3)", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []{{asType .}}{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []{{asType .}}{11, 15}}, + {{end -}} + {{end -}} + {{end -}} +} + +func TestDense_Max(t *testing.T){ + assert := assert.New(t) + var T, T2 *Dense + var err error + + for _, mts := range maxTests { + T = New(WithShape(mts.shape...), WithBacking(Range(mts.of, 0, mts.shape.TotalSize()))) + if T2, err = T.Max(mts.along...); err != nil{ + t.Error(err) + continue + } + assert.True(mts.correctShape.Eq(T2.Shape())) + assert.Equal(mts.correct, T2.Data()) + } + /* IDIOT TESTING TIME */ + _, err = T.Max(1000) + assert.NotNil(err) +} +` + +const testDenseMinRaw = `var minTests = []struct { + name string + of dtype.Dtype + shape Shape + along []int + + correctShape Shape + correct interface{} +}{ + {{range .Kinds -}} + {{if isNumber . -}} + {{if isOrd . -}} + {"common case: T.Min() for {{.}}", {{asType .|title}}, Shape{2,3}, []int{}, ScalarShape(), {{asType .}}(0)}, + {"A.Min(0)", {{asType .|title}}, Shape{2,3}, []int{0}, Shape{3}, []{{asType .}}{0, 1, 2}}, + {"A.Min(1)", {{asType .|title}}, Shape{2,3}, []int{1}, Shape{2}, []{{asType .}}{0, 3}}, + {"A.Min(0,1)", {{asType .|title}}, Shape{2,3}, []int{0, 1}, ScalarShape(), {{asType .}}(0)}, + {"A.Min(1,0)", {{asType .|title}}, Shape{2,3}, []int{1, 0}, ScalarShape(), {{asType .}}(0)}, + {"3T.Min(1,2)", {{asType . | title}}, Shape{2,3,4}, []int{1,2}, Shape{2}, []{{asType .}}{0,12} }, + {"4T.Min()", {{asType . | title}}, Shape{2, 2, 2, 2},[]int{}, ScalarShape(), {{asType .}}(0)}, + {"4T.Min(1,3)", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []{{asType .}}{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", {{asType . 
| title}}, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []{{asType .}}{0, 4}}, + {{end -}} + {{end -}} + {{end -}} +} + +func TestDense_Min(t *testing.T){ + assert := assert.New(t) + var T, T2 *Dense + var err error + + for _, mts := range minTests { + T = New(WithShape(mts.shape...), WithBacking(Range(mts.of, 0, mts.shape.TotalSize()))) + if T2, err = T.Min(mts.along...); err != nil{ + t.Error(err) + continue + } + assert.True(mts.correctShape.Eq(T2.Shape())) + assert.Equal(mts.correct, T2.Data()) + } + + /* IDIOT TESTING TIME */ + _, err = T.Min(1000) + assert.NotNil(err) +} +` + +var ( + testDenseSum *template.Template + testDenseMax *template.Template + testDenseMin *template.Template +) + +func init() { + testDenseSum = template.Must(template.New("testDenseSum").Funcs(funcs).Parse(testDenseSumRaw)) + testDenseMax = template.Must(template.New("testDenseMax").Funcs(funcs).Parse(testDenseMaxRaw)) + testDenseMin = template.Must(template.New("testDenseMin").Funcs(funcs).Parse(testDenseMinRaw)) +} + +func generateDenseReductionMethodsTests(f io.Writer, generic Kinds) { + testDenseSum.Execute(f, generic) + fmt.Fprint(f, "\n") + testDenseMax.Execute(f, generic) + fmt.Fprint(f, "\n") + testDenseMin.Execute(f, generic) +} diff --git a/genlib2/dense_reduction_tests.go b/genlib2/dense_reduction_tests.go index 2c35efa..06f78c0 100644 --- a/genlib2/dense_reduction_tests.go +++ b/genlib2/dense_reduction_tests.go @@ -6,7 +6,7 @@ import ( ) const testDenseReduceRaw = `var denseReductionTests = []struct { - of Dtype + of dtype.Dtype fn interface{} def interface{} axis int diff --git a/genlib2/engine.go b/genlib2/engine.go index 8792e2e..6f551fc 100644 --- a/genlib2/engine.go +++ b/genlib2/engine.go @@ -7,10 +7,12 @@ import ( ) type EngineArith struct { + isStdDenseEng bool Name string VecVar string PrepData string TypeClassCheck string + IsCommutative bool VV bool LeftVec bool @@ -33,9 +35,11 @@ func (fn *EngineArith) Signature() *Signature { case fn.VV: paramNames = []string{"a", "b", "opts"} paramTemplates = []*template.Template{tensorType, tensorType, splatFuncOptType} + default: paramNames = []string{"t", "s", "leftTensor", "opts"} paramTemplates = []*template.Template{tensorType, interfaceType, boolType, splatFuncOptType} + } return &Signature{ Name: fn.methName(), @@ -97,6 +101,7 @@ func generateStdEngArith(f io.Writer, ak Kinds) { Name: abo.Name(), VV: true, TypeClassCheck: "Number", + IsCommutative: abo.IsCommutative, } methods = append(methods, meth) } @@ -234,6 +239,123 @@ func generateStdEngCmp(f io.Writer, ak Kinds) { } } +type EngineMinMax struct { + Name string + VecVar string + PrepData string + TypeClassCheck string + Kinds []reflect.Kind + + VV bool + LeftVec bool +} + +func (fn *EngineMinMax) methName() string { + switch { + case fn.VV: + return fn.Name + default: + return fn.Name + "Scalar" + } +} + +func (fn *EngineMinMax) Signature() *Signature { + var paramNames []string + var paramTemplates []*template.Template + + switch { + case fn.VV: + paramNames = []string{"a", "b", "opts"} + paramTemplates = []*template.Template{tensorType, tensorType, splatFuncOptType} + default: + paramNames = []string{"t", "s", "leftTensor", "opts"} + paramTemplates = []*template.Template{tensorType, interfaceType, boolType, splatFuncOptType} + } + return &Signature{ + Name: fn.methName(), + NameTemplate: plainName, + ParamNames: paramNames, + ParamTemplates: paramTemplates, + Err: false, + } +} + +func (fn *EngineMinMax) WriteBody(w io.Writer) { + var prep *template.Template + switch { + case fn.VV: + 
prep = prepVV + fn.VecVar = "a" + case !fn.VV && fn.LeftVec: + fn.VecVar = "t" + fn.PrepData = "prepDataVS" + prep = prepMixed + default: + fn.VecVar = "t" + fn.PrepData = "prepDataSV" + prep = prepMixed + } + template.Must(prep.New("prep").Parse(minmaxPrepRaw)) + prep.Execute(w, fn) + agg2MinMaxBody.Execute(w, fn) +} + +func (fn *EngineMinMax) Write(w io.Writer) { + if tmpl, ok := cmpDocStrings[fn.methName()]; ok { + type tmp struct { + Left, Right string + } + var ds tmp + if fn.VV { + ds.Left = "a" + ds.Right = "b" + } else { + ds.Left = "t" + ds.Right = "s" + } + tmpl.Execute(w, ds) + } + sig := fn.Signature() + w.Write([]byte("func (e StdEng) ")) + sig.Write(w) + w.Write([]byte("(retVal Tensor, err error) {\n")) + fn.WriteBody(w) + w.Write([]byte("}\n\n")) +} + +func generateStdEngMinMax(f io.Writer, ak Kinds) { + methods := []*EngineMinMax{ + &EngineMinMax{ + Name: "MinBetween", + VV: true, + TypeClassCheck: "Ord", + }, + &EngineMinMax{ + Name: "MaxBetween", + VV: true, + TypeClassCheck: "Ord", + }, + } + f.Write([]byte(`var ( + _ MinBetweener = StdEng{} + _ MaxBetweener = StdEng{} +) +`)) + // VV + for _, meth := range methods { + meth.Write(f) + meth.VV = false + } + + // Scalar-Vector + for _, meth := range methods { + meth.Write(f) + meth.LeftVec = true + } +} + +/* UNARY METHODS */ + type EngineUnary struct { Name string TypeClassCheck string @@ -269,18 +391,18 @@ func (fn *EngineUnary) Write(w io.Writer) { func generateStdEngUncondUnary(f io.Writer, ak Kinds) { tcc := []string{ - "Number", // Neg - "Number", // Inv - "Number", // Square - "Number", // Cube - "FloatCmplx", // Exp - "FloatCmplx", // Tanhh - "FloatCmplx", // Log - "Float", // Log2 - "FloatCmplx", // Log10 - "FloatCmplx", // Sqrt - "Float", // Cbrt - "Float", // InvSqrt + "Number", // Neg + "Number", // Inv + "Number", // Square + "Number", // Cube + "FloatComplex", // Exp + "FloatComplex", // Tanhh + "FloatComplex", // Log + "Floats", // Log2 + "FloatComplex", // Log10 + "FloatComplex", // Sqrt + "Floats", // Cbrt + "Floats", // InvSqrt } var gen []*EngineUnary for i, u := range unconditionalUnaries { diff --git a/genlib2/generic_arith.go b/genlib2/generic_arith.go index fa6fa39..a0ab358 100644 --- a/genlib2/generic_arith.go +++ b/genlib2/generic_arith.go @@ -11,6 +11,7 @@ type GenericVecVecArith struct { TypedBinOp Iter bool Incr bool + WithRecv bool // not many BinOps have this Check TypeClass // can be nil CheckTemplate string } @@ -23,6 +24,8 @@ func (fn *GenericVecVecArith) Name() string { return fmt.Sprintf("%sIter", fn.TypedBinOp.Name()) case !fn.Iter && fn.Incr: return fmt.Sprintf("%sIncr", fn.TypedBinOp.Name()) + case fn.WithRecv: + return fmt.Sprintf("%vRecv", fn.TypedBinOp.Name()) default: return fmt.Sprintf("Vec%s", fn.TypedBinOp.Name()) } @@ -45,6 +48,9 @@ func (fn *GenericVecVecArith) Signature() *Signature { case !fn.Iter && fn.Incr: paramNames = []string{"a", "b", "incr"} paramTemplates = []*template.Template{sliceType, sliceType, sliceType} + case fn.WithRecv: + paramNames = []string{"a", "b", "recv"} + paramTemplates = []*template.Template{sliceType, sliceType, sliceType} default: paramNames = []string{"a", "b"} paramTemplates = []*template.Template{sliceType, sliceType} @@ -97,6 +103,11 @@ func (fn *GenericVecVecArith) WriteBody(w io.Writer) { Right = "b[i]" T = template.Must(T.Parse(genericLoopRaw)) template.Must(T.New("loopbody").Parse(basicIncr)) + case fn.WithRecv: + Range = "recv" + Right = "b[i]" + T = template.Must(T.Parse(genericLoopRaw)) + 
template.Must(T.New("loopbody").Parse(basicSet)) default: Right = "b[i]" T = template.Must(T.Parse(genericLoopRaw)) @@ -130,7 +141,7 @@ func (fn *GenericVecVecArith) WriteBody(w io.Writer) { func (fn *GenericVecVecArith) Write(w io.Writer) { sig := fn.Signature() - if !fn.Iter && isFloat(fn.Kind()) { + if !fn.Iter && isFloat(fn.Kind()) && !fn.WithRecv { // golinkPragma.Execute(w, fn) w.Write([]byte("func ")) sig.Write(w) @@ -148,7 +159,9 @@ func (fn *GenericVecVecArith) Write(w io.Writer) { switch { case !fn.Iter && fn.Incr: w.Write([]byte("{\na = a[:len(a)]; b = b[:len(a)]; incr = incr[:len(a)]\n")) - case !fn.Iter && !fn.Incr: + case fn.WithRecv: + w.Write([]byte("{\na = a[:len(recv)]; b = b[:len(recv)]\n")) + case !fn.Iter && !fn.Incr && !fn.WithRecv: w.Write([]byte("{\na = a[:len(a)]; b = b[:len(a)]\n")) default: w.Write([]byte("{\n")) @@ -390,6 +403,7 @@ func makeGenericVecVecAriths(tbo []TypedBinOp) (retVal []*GenericVecVecArith) { fn.Check = panicsDiv0 fn.CheckTemplate = check0 } + retVal = append(retVal, fn) } @@ -457,6 +471,12 @@ func generateGenericVecVecArith(f io.Writer, ak Kinds) { for _, g := range gen { g.Write(f) } + for _, g := range gen { + g.Incr = false + g.Iter = false + g.WithRecv = true + g.Write(f) + } } func generateGenericMixedArith(f io.Writer, ak Kinds) { diff --git a/genlib2/generic_cmp.go b/genlib2/generic_cmp.go index e008e11..fecf360 100644 --- a/genlib2/generic_cmp.go +++ b/genlib2/generic_cmp.go @@ -394,6 +394,23 @@ const genericElMinMaxRaw = `func VecMin{{short . | title}}(a, b []{{asType .}}) } } } + +func MinSV{{short . | title}}(a {{asType .}}, b []{{asType .}}){ + for i := range b { + if a < b[i]{ + b[i] = a + } + } +} + +func MinVS{{short . | title}}(a []{{asType .}}, b {{asType .}}){ + for i := range a { + if b < a[i]{ + a[i] = b + } + } +} + func VecMax{{short . | title}}(a, b []{{asType .}}) { a = a[:len(a)] b = b[:len(a)] @@ -404,6 +421,138 @@ func VecMax{{short . | title}}(a, b []{{asType .}}) { } } } + + + +func MaxSV{{short . | title}}(a {{asType .}}, b []{{asType .}}){ + for i := range b { + if a > b[i]{ + b[i] = a + } + } +} + +func MaxVS{{short . | title}}(a []{{asType .}}, b {{asType .}}){ + for i := range a { + if b > a[i]{ + a[i] = b + } + } +} +` + +// Iter Min/Max +const genericIterMinMaxRaw = `func MinIterSV{{short . | title}}(a {{asType .}}, b []{{asType .}}, bit Iterator) (err error){ + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil{ + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVS{{short . | title}}(a []{{asType .}}, b {{asType .}}, ait Iterator) (err error){ + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil{ + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIter{{short . | title}}(a , b []{{asType .}}, ait, bit Iterator) (err error){ + var i,j int + var validi ,validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil{ + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil{ + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + + +func MaxIterSV{{short . 
| title}}(a {{asType .}}, b []{{asType .}}, bit Iterator) (err error){ + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil{ + err = handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVS{{short . | title}}(a []{{asType .}}, b {{asType .}}, ait Iterator) (err error){ + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil{ + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func VecMaxIter{{short . | title}}(a , b []{{asType .}}, ait, bit Iterator) (err error){ + var i,j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil{ + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil{ + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} + ` // scalar Min/Max @@ -413,6 +562,7 @@ const genericScalarMinMaxRaw = `func Min{{short .}}(a, b {{asType .}}) (c {{asTy return b } + func Max{{short .}}(a, b {{asType .}}) (c {{asType .}}) {if a > b { return a } @@ -421,13 +571,15 @@ func Max{{short .}}(a, b {{asType .}}) (c {{asType .}}) {if a > b { ` var ( - genericElMinMax *template.Template - genericMinMax *template.Template + genericElMinMax *template.Template + genericMinMax *template.Template + genericElMinMaxIter *template.Template ) func init() { genericElMinMax = template.Must(template.New("genericVecVecMinMax").Funcs(funcs).Parse(genericElMinMaxRaw)) genericMinMax = template.Must(template.New("genericMinMax").Funcs(funcs).Parse(genericScalarMinMaxRaw)) + genericElMinMaxIter = template.Must(template.New("genericIterMinMax").Funcs(funcs).Parse(genericIterMinMaxRaw)) } func generateMinMax(f io.Writer, ak Kinds) { @@ -438,4 +590,8 @@ func generateMinMax(f io.Writer, ak Kinds) { for _, k := range filter(ak.Kinds, isOrd) { genericMinMax.Execute(f, k) } + + for _, k := range filter(ak.Kinds, isOrd) { + genericElMinMaxIter.Execute(f, k) + } } diff --git a/genlib2/generic_utils.go b/genlib2/generic_utils.go index 7c207fa..8d5f87b 100644 --- a/genlib2/generic_utils.go +++ b/genlib2/generic_utils.go @@ -8,7 +8,7 @@ import ( const rangeRaw = `// Range creates a ranged array with a given type. It panics if the Dtype is not supported or does not represent a naturally orderable type (strings, pointers etc) // Do note that the range algorithm is very simple, and simply does increments or decrements of 1. This means for floating point types // you're not able to create a range with a 0.1 increment step, and for complex number types, the imaginary part will always be 0i -func Range(dt Dtype, start, end int) interface{} { +func Range(dt dtype.Dtype, start, end int) interface{} { size := end - start incr := true if start > end { @@ -58,8 +58,8 @@ func Range(dt Dtype, start, end int) interface{} { const randomRaw = `// Random creates an array of random numbers of the given type. // For complex Dtypes, the imaginary component will be 0. // -// This function is only useful in cases where the randomness is not vital. -func Random(dt Dtype, size int) interface{} { +// This function is only useful in cases where the randomness is not vital. 
+func Random(dt dtype.Dtype, size int) interface{} { r := rand.New(rand.NewSource(1337)) switch dt.Kind() { {{range .Kinds -}} diff --git a/genlib2/internaleng.go b/genlib2/internaleng.go index ee5f15c..48a3938 100644 --- a/genlib2/internaleng.go +++ b/genlib2/internaleng.go @@ -10,9 +10,10 @@ import ( type InternalEngArithMethod struct { BinOp - Kinds []reflect.Kind - Incr bool - Iter bool + Kinds []reflect.Kind + Incr bool + Iter bool + WithRecv bool } type eLoopBody struct { @@ -30,6 +31,8 @@ func (fn *InternalEngArithMethod) Name() string { return fmt.Sprintf("%sIncr", fn.BinOp.Name()) case !fn.Incr && fn.Iter: return fmt.Sprintf("%sIter", fn.BinOp.Name()) + case fn.WithRecv: + return fmt.Sprintf("%sRecv", fn.BinOp.Name()) default: return fn.BinOp.Name() } @@ -48,6 +51,9 @@ func (fn *InternalEngArithMethod) Signature() *Signature { case !fn.Iter && fn.Incr: paramNames = []string{"t", "a", "b", "incr"} paramTemplates = []*template.Template{reflectType, arrayType, arrayType, arrayType} + case fn.WithRecv: + paramNames = []string{"t", "a", "b", "recv"} + paramTemplates = []*template.Template{reflectType, arrayType, arrayType, arrayType} default: paramNames = []string{"t", "a", "b"} paramTemplates = []*template.Template{reflectType, arrayType, arrayType} @@ -72,6 +78,8 @@ func (fn *InternalEngArithMethod) WriteBody(w io.Writer) { T = eArithIncr case fn.Iter && !fn.Incr: T = eArithIter + case fn.WithRecv: + T = eArithRecv default: T = eArith } @@ -107,23 +115,36 @@ func generateEArith(f io.Writer, kinds Kinds) { methods = append(methods, meth) } + // write vanilla for _, meth := range methods { meth.Write(f) meth.Incr = true } + // write incr for _, meth := range methods { meth.Write(f) meth.Incr = false meth.Iter = true } + + // write iter for _, meth := range methods { meth.Write(f) meth.Incr = true } + // write iter incr for _, meth := range methods { meth.Write(f) + meth.Incr = false + meth.Iter = false + } + + // write recv + for _, meth := range methods { + meth.WithRecv = true + meth.Write(f) } } @@ -308,6 +329,102 @@ func generateECmp(f io.Writer, kinds Kinds) { } } +/* MIN/MAX BETWEEN */ + +type InternalEngMinMaxBetween struct { + BinOp + Kinds []reflect.Kind + Iter bool +} + +func (fn *InternalEngMinMaxBetween) Name() string { + name := fn.BinOp.Name() + + switch { + case fn.Iter: + return fmt.Sprintf("%sBetweenIter", name) + default: + return name + "Between" + } +} + +func (fn *InternalEngMinMaxBetween) Signature() *Signature { + var paramNames []string + var paramTemplates []*template.Template + + switch { + case fn.Iter: + paramNames = []string{"t", "a", "b", "ait", "bit"} + paramTemplates = []*template.Template{reflectType, arrayType, arrayType, iteratorType, iteratorType} + default: + paramNames = []string{"t", "a", "b"} + paramTemplates = []*template.Template{reflectType, arrayType, arrayType} + } + return &Signature{ + Name: fn.Name(), + NameTemplate: plainName, + ParamNames: paramNames, + ParamTemplates: paramTemplates, + + Err: true, + } +} + +func (fn *InternalEngMinMaxBetween) WriteBody(w io.Writer) { + var T *template.Template + + switch { + case fn.Iter: + T = eMinMaxIter + default: + T = eMinMaxSame + } + + lb := eLoopBody{ + BinOp: fn.BinOp, + Kinds: fn.Kinds, + } + T.Execute(w, lb) +} + +func (fn *InternalEngMinMaxBetween) Write(w io.Writer) { + w.Write([]byte("func (e E) ")) + sig := fn.Signature() + sig.Write(w) + w.Write([]byte("{\n")) + fn.WriteBody(w) + w.Write([]byte("}\n\n")) +} + +func generateEMinMaxBetween(f io.Writer, kinds Kinds) { + minmaxOps := 
[]cmpOp{cmpBinOps[0], cmpBinOps[2]} // Gt and Lt + minmaxOps[0].name = "Max" + minmaxOps[1].name = "Min" + var methods []*InternalEngMinMaxBetween + for _, bo := range minmaxOps { + var ks []reflect.Kind + for _, k := range kinds.Kinds { + if tc := bo.TypeClass(); tc != nil && tc(k) { + ks = append(ks, k) + } + } + meth := &InternalEngMinMaxBetween{ + BinOp: bo, + Kinds: ks, + } + methods = append(methods, meth) + } + + for _, meth := range methods { + meth.Write(f) + meth.Iter = true + } + for _, meth := range methods { + meth.Write(f) + } + +} + /* REDUCTION */ type InternalEngReduce struct { diff --git a/genlib2/main.go b/genlib2/main.go index 7207327..f9923dc 100644 --- a/genlib2/main.go +++ b/genlib2/main.go @@ -6,16 +6,18 @@ import ( "log" "os" "os/exec" + "os/user" "path" "path/filepath" "reflect" + "runtime" "strings" ) const genmsg = "Code generated by genlib2. DO NOT EDIT." var ( - gopath, tensorPkgLoc, execLoc, storageLoc string + gopath, tensorPkgLoc, nativePkgLoc, execLoc, storageLoc string ) type Kinds struct { @@ -24,7 +26,24 @@ type Kinds struct { func init() { gopath = os.Getenv("GOPATH") + + // now that go can have a default gopath, this checks that path + if gopath == "" { + usr, err := user.Current() + if err != nil { + log.Fatal(err) + } + gopath = path.Join(usr.HomeDir, "go") + stat, err := os.Stat(gopath) + if err != nil { + log.Fatal(err) + } + if !stat.IsDir() { + log.Fatal("You need to define a $GOPATH") + } + } tensorPkgLoc = path.Join(gopath, "src/gorgonia.org/tensor") + nativePkgLoc = path.Join(gopath, "src/gorgonia.org/tensor/native") execLoc = path.Join(gopath, "src/gorgonia.org/tensor/internal/execution") storageLoc = path.Join(gopath, "src/gorgonia.org/tensor/internal/storage") } @@ -33,6 +52,7 @@ func main() { pregenerate() // storage + pipeline(storageLoc, "consts.go", Kinds{allKinds}, generateReflectTypes) pipeline(storageLoc, "getset.go", Kinds{allKinds}, generateHeaderGetSet) pipeline(tensorPkgLoc, "array_getset.go", Kinds{allKinds}, generateArrayMethods) @@ -53,6 +73,7 @@ func main() { pipeline(execLoc, "eng_arith.go", Kinds{allKinds}, generateEArith) pipeline(execLoc, "eng_map.go", Kinds{allKinds}, generateEMap) pipeline(execLoc, "eng_cmp.go", Kinds{allKinds}, generateECmp) + pipeline(execLoc, "eng_minmaxbetween.go", Kinds{allKinds}, generateEMinMaxBetween) pipeline(execLoc, "eng_reduce.go", Kinds{allKinds}, generateEReduce) pipeline(execLoc, "eng_unary.go", Kinds{allKinds}, generateUncondEUnary, generateCondEUnary, generateSpecialEUnaries) pipeline(execLoc, "reduction_specialization.go", Kinds{allKinds}, generateReductionSpecialization) @@ -62,6 +83,7 @@ func main() { pipeline(tensorPkgLoc, "defaultengine_arith.go", Kinds{allKinds}, generateStdEngArith) pipeline(tensorPkgLoc, "defaultengine_cmp.go", Kinds{allKinds}, generateStdEngCmp) pipeline(tensorPkgLoc, "defaultengine_unary.go", Kinds{allKinds}, generateStdEngUncondUnary, generateStdEngCondUnary) + pipeline(tensorPkgLoc, "defaultengine_minmax.go", Kinds{allKinds}, generateStdEngMinMax) // level 3 aggregation pipeline(tensorPkgLoc, "dense_arith.go", Kinds{allKinds}, generateDenseArith) @@ -93,10 +115,26 @@ func main() { pipeline(tensorPkgLoc, "api_unary_generated_test.go", Kinds{allKinds}, generateAPIUnaryTests) pipeline(tensorPkgLoc, "api_cmp_generated_test.go", Kinds{allKinds}, generateAPICmpTests, generateAPICmpMixedTests) pipeline(tensorPkgLoc, "dense_cmp_test.go", Kinds{allKinds}, generateDenseMethodCmpTests, generateDenseMethodCmpMixedTests) + + // native iterators - the ones in the tensor 
package + pipeline(tensorPkgLoc, "iterator_native.go", Kinds{allKinds}, generateNativeIterators(false)) + pipeline(tensorPkgLoc, "iterator_native_test.go", Kinds{allKinds}, generateNativeIteratorTests(false)) + pipeline(tensorPkgLoc, "select_native.go", Kinds{allKinds}, generateNativeSelect(false)) + pipeline(tensorPkgLoc, "select_native_test.go", Kinds{allKinds}, generateNativeSelectTests(false)) + + // native iterators - exported into gorgonia.org/tensor/native + pipeline(nativePkgLoc+"_unsafe", "iterator_native.go", Kinds{allKinds}, generateNativeIteratorStubs) + pipeline(nativePkgLoc+"_purego", "iterator_native_purego.go", Kinds{allKinds}, generateNativeIterators(true)) + pipeline(nativePkgLoc, "iterator_native_test.go", Kinds{allKinds}, generateNativeIteratorTests(true)) + pipeline(nativePkgLoc+"_unsafe", "select_native.go", Kinds{allKinds}, generateNativeSelectStubs) + pipeline(nativePkgLoc+"_purego", "select_native_purego.go", Kinds{allKinds}, generateNativeSelect(true)) + pipeline(nativePkgLoc, "select_native_test.go", Kinds{allKinds}, generateNativeSelectTests(true)) + pipeline(nativePkgLoc, "utils.go", Kinds{allKinds}, generateNativeIterChecks, generateNativeSelChecks) } func pipeline(pkg, filename string, kinds Kinds, fns ...func(io.Writer, Kinds)) { - fullpath := path.Join(pkg, filename) + pkgpath := strings.Replace(strings.Replace(pkg, "_unsafe", "", -1), "_purego", "", -1) + fullpath := path.Join(pkgpath, filename) f, err := os.Create(fullpath) if err != nil { log.Printf("fullpath %q", fullpath) @@ -115,7 +153,12 @@ func pipeline(pkg, filename string, kinds Kinds, fns ...func(io.Writer, Kinds)) log.Fatalf("Go imports failed with %v for %q", err, fullpath) } - cmd = exec.Command("sed", "-i", `s/github.com\/alecthomas\/assert/github.com\/stretchr\/testify\/assert/g`, fullpath) + // account for differences in the postix from the linux sed + if runtime.GOOS == "darwin" || strings.HasSuffix(runtime.GOOS, "bsd") { + cmd = exec.Command("sed", "-i", "", `s/github.com\/alecthomas\/assert/github.com\/stretchr\/testify\/assert/g`, fullpath) + } else { + cmd = exec.Command("sed", "-E", "-i", `s/github.com\/alecthomas\/assert/github.com\/stretchr\/testify\/assert/g`, fullpath) + } if err = cmd.Run(); err != nil { if err.Error() != "exit status 4" { // exit status 4 == not found log.Fatalf("sed failed with %v for %q", err.Error(), fullpath) @@ -136,6 +179,9 @@ func pregenerate() error { if err := cleanup(execLoc); err != nil { return err } + if err := cleanup(nativePkgLoc); err != nil { + return err + } return cleanup(tensorPkgLoc) } diff --git a/genlib2/native_iterator.go b/genlib2/native_iterator.go new file mode 100644 index 0000000..1d7a85c --- /dev/null +++ b/genlib2/native_iterator.go @@ -0,0 +1,270 @@ +package main + +import ( + "fmt" + "io" + "reflect" + "text/template" +) + +const checkNativeiterable = `func checkNativeIterable(t *Dense, dims int, dt dtype.Dtype) error { + // checks: + if !t.IsNativelyAccessible() { + return errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible") + } + + if t.Shape().Dims() != dims { + return errors.Errorf("Cannot convert *Dense to native iterator. Expected number of dimension: %d, T has got %d dimensions (Shape: %v)", dims, t.Dims(), t.Shape()) + } + + if t.F() || t.RequiresIterator() { + return errors.Errorf("Not yet implemented: native matrix for colmajor or unpacked matrices") + } + + if t.Dtype() != dt { + return errors.Errorf("Conversion to native iterable only works on %v. 
Got %v", dt, t.Dtype()) + } + + return nil +} +` + +const nativeIterRaw = ` +{{- $vecName := ( printf "nativeDenseVector%s" (short .K) ) -}} +{{- $matName := ( printf "nativeDenseMatrix%s" (short .K) ) -}} +{{- $T3Name := ( printf "nativeDenseTensor3%s" (short .K) ) -}} +{{- if .N -}} + {{- $vecName = ( printf "Vector%s" (short .K) ) -}} + {{- $matName = ( printf "Matrix%s" (short .K) ) -}} + {{- $T3Name = ( printf "Tensor3%s" (short .K) ) -}} +{{- end -}} + +// {{$vecName}} converts a *Dense into a []{{asType .K}} +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func {{$vecName}}(t *Dense) (retVal []{{asType .K}}, err error) { + if err = checkNativeIterable(t, 1, {{reflectKind .K}}); err != nil { + return nil, err + } + return t.{{sliceOf .K}}, nil +} + +// {{$matName}} converts a *Dense into a [][]{{asType .K}} +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func {{$matName}}(t *Dense) (retVal [][]{{asType .K}}, err error) { + if err = checkNativeIterable(t, 2, {{reflectKind .K}}); err != nil { + return nil, err + } + + data := t.{{sliceOf .K}} + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]{{asType .K}}, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]{{asType .K}}, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// {{$T3Name}} converts a *Dense into a [][][]{{asType .K}}. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func {{$T3Name}}(t *Dense) (retVal [][][]{{asType .K}}, err error) { + if err = checkNativeIterable(t, 3, {{reflectKind .K}}); err != nil { + return nil, err + } + + data := t.{{sliceOf .K}} + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]{{asType .K}}, layers) + for i := range retVal { + retVal[i] = make([][]{{asType .K}}, rows) + for j := range retVal[i] { + retVal[i][j] = make([]{{asType .K}}, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} +` + +const nativeIterStubsRaw = `//go:linkname Vector{{short .}} gorgonia.org/tensor.nativeDenseVector{{short .}} + +// Vector{{short .}} converts a *Dense into a []{{asType .}} +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func Vector{{short .}}(t *tensor.Dense) (retVal []{{asType .}}, err error) + +//go:linkname Matrix{{short .}} gorgonia.org/tensor.nativeDenseMatrix{{short .}} + +// Matrix{{short .}} converts a *Dense into a [][]{{asType .}} +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func Matrix{{short .}}(t *tensor.Dense) (retVal [][]{{asType .}}, err error) + +//go:linkname Tensor3{{short .}} gorgonia.org/tensor.nativeDenseTensor3{{short .}} + +// Tensor3{{short .}} converts a *Dense into a [][][]{{asType .}}. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func Tensor3{{short .}}(t *tensor.Dense) (retVal [][][]{{asType .}}, err error) +` + +const nativeIterTestRaw = ` +{{- $pkgTVecName := ( printf "nativeDenseVector%s" (short .K) ) -}} +{{- $pkgTMatName := ( printf "nativeDenseMatrix%s" (short .K) ) -}} +{{- $pkgTT3Name := ( printf "nativeDenseTensor3%s" (short .K) ) -}} +{{- $pkgNVecName := ( printf "Vector%s" (short .K) ) -}} +{{- $pkgNMatName := ( printf "Matrix%s" (short .K) ) -}} +{{- $pkgNT3Name := ( printf "Tensor3%s" (short .K) ) -}} +{{- $vecName := "" -}} +{{- $matName := "" -}} +{{- $T3Name := "" -}} +{{- if .N -}} + {{- $vecName = $pkgNVecName -}} + {{- $matName = $pkgNMatName -}} + {{- $T3Name = $pkgNT3Name -}} +{{- else -}} + {{- $vecName = $pkgTVecName -}} + {{- $matName = $pkgTMatName -}} + {{- $T3Name = $pkgTT3Name -}} +{{end -}} + + +func Test_{{$vecName}}(t *testing.T) { + assert := assert.New(t) + var T *Dense + {{if isRangeable .K -}} + T = New(WithBacking(Range({{reflectKind .K}}, 0, 6)), WithShape(6)) + {{else -}} + T = New(Of({{reflectKind .K}}), WithShape(6)) + {{end -}} + it, err := {{$vecName}}(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_{{$matName}}(t *testing.T) { + assert := assert.New(t) + var T *Dense + {{if isRangeable .K -}} + T = New(WithBacking(Range({{reflectKind .K}}, 0, 6)), WithShape(2, 3)) + {{else -}} + T = New(Of({{reflectKind .K}}), WithShape(2, 3)) + {{end -}} + it, err := {{$matName}}(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_{{$T3Name}}(t *testing.T) { + assert := assert.New(t) + var T *Dense + {{if isRangeable .K -}} + T = New(WithBacking(Range({{reflectKind .K}}, 0, 24)), WithShape(2, 3, 4)) + {{else -}} + T = New(Of({{reflectKind .K}}), WithShape(2, 3, 4)) + {{end -}} + it, err := {{$T3Name}}(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} +` + +var ( + NativeIter *template.Template + NativeIterTest *template.Template + NativeIterStubs *template.Template +) + +func init() { + NativeIter = template.Must(template.New("NativeIter").Funcs(funcs).Parse(nativeIterRaw)) + NativeIterTest = template.Must(template.New("NativeIterTest").Funcs(funcs).Parse(nativeIterTestRaw)) + NativeIterStubs = template.Must(template.New("NativeStubs").Funcs(funcs).Parse(nativeIterStubsRaw)) +} + +// generateNativeIterators generates the code for native iterators. `isNative` represents whether the code is generated for the `native` package or not. +// isNative will only be true for the `purego` build tag. +func generateNativeIterators(isNative bool) func(f io.Writer, ak Kinds) { + type IterTup struct { + N bool + K reflect.Kind + } + return func(f io.Writer, ak Kinds) { + if isNative { + // checkNativeIteratble is separately generated and placed into util.go in the `native` package + // so there is no need to generate that here. 
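+			// All that is generated here is the dot-import of gorgonia.org/tensor, so that the
+			// generated code can refer to *Dense and the unqualified Dtype values directly.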
+ fmt.Fprintf(f, importUnqualifiedTensor) + } else { + fmt.Fprintf(f, "%v\n", checkNativeiterable) + } + ks := filter(ak.Kinds, isSpecialized) + for _, k := range ks { + fmt.Fprintf(f, "/* Native Iterables for %v */\n\n", k) + NativeIter.Execute(f, IterTup{N: isNative, K: k}) + fmt.Fprint(f, "\n\n") + } + } +} + +func generateNativeIteratorTests(isNative bool) func(f io.Writer, ak Kinds) { + type IterTup struct { + N bool + K reflect.Kind + } + return func(f io.Writer, ak Kinds) { + if isNative { + fmt.Fprintf(f, importUnqualifiedTensor) + } + ks := filter(ak.Kinds, isSpecialized) + for _, k := range ks { + NativeIterTest.Execute(f, IterTup{N: isNative, K: k}) + fmt.Fprint(f, "\n\n") + } + } +} + +func generateNativeIteratorStubs(f io.Writer, ak Kinds) { + fmt.Fprintf(f, importUnsafe) // this is required for go:linkname to work + ks := filter(ak.Kinds, isSpecialized) + for _, k := range ks { + NativeIterStubs.Execute(f, k) + fmt.Fprint(f, "\n\n") + } +} + +func generateNativeIterChecks(f io.Writer, ak Kinds) { + fmt.Fprintf(f, importUnqualifiedTensor) + fmt.Fprintf(f, "%v\n", checkNativeiterable) +} diff --git a/genlib2/native_select.go b/genlib2/native_select.go new file mode 100644 index 0000000..1095668 --- /dev/null +++ b/genlib2/native_select.go @@ -0,0 +1,196 @@ +package main + +import ( + "fmt" + "io" + "reflect" + "text/template" +) + +const checkNativeSelectable = `func checkNativeSelectable(t *Dense, axis int, dt dtype.Dtype) error { + if !t.IsNativelyAccessible() { + return errors.New("Cannot select on non-natively accessible data") + } + if axis >= t.Shape().Dims() && !(t.IsScalar() && axis == 0) { + return errors.Errorf("Cannot select on axis %d. Shape is %v", axis, t.Shape()) + } + if t.F() || t.RequiresIterator() { + return errors.Errorf("Not yet implemented: native select for colmajor or unpacked matrices") + } + if t.Dtype() != dt { + return errors.Errorf("Native selection only works on %v. Got %v", dt, t.Dtype()) + } + return nil +} +` +const nativeSelectRaw = ` +{{- $selName := ( printf "nativeSelect%s" (short .K) ) -}} +{{- if .N -}} + {{- $selName = ( printf "Select%s" (short .K) ) -}} +{{- end -}} + +// {{$selName}} creates a slice of flat data types. See Example of NativeSelectF64. 
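+// The returned rows alias t's underlying data; no copy is made. For instance, selecting
+// on axis 1 of a (2, 3, 4, 5) tensor yields 6 rows of 20 elements each.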
+func {{$selName}}(t *Dense, axis int) (retVal [][]{{asType .K}}, err error) { + if err := checkNativeSelectable(t, axis, {{reflectKind .K}}); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]{{asType .K}}, 1) + retVal[0] = t.{{sliceOf .K}} + case 2: + if axis == 0 { + return {{if .N}}Matrix{{short .K}}{{else}}nativeDenseMatrix{{short .K}}{{end}}(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.{{sliceOf .K}} + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]{{asType .K}}, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]{{asType .K}}, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} +` +const nativeSelectTestRaw = ` +{{- $selName := ( printf "nativeSelect%s" (short .K) ) -}} +{{- if .N -}} + {{- $selName = ( printf "Select%s" (short .K) ) -}} +{{- end -}} +func Test{{$selName}}(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]{{asType .K}} + T = New(Of({{reflectKind .K}}), WithShape(2, 3, 4, 5), ) + if x, err = {{$selName}}(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of({{reflectKind .K}}), WithShape(2, 3, 4, 5), ) + if x, err = {{$selName}}(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of({{reflectKind .K}}), WithShape(2, 3, 4, 5), ) + if x, err = {{$selName}}(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of({{reflectKind .K}}), WithShape(2, 3), ) + if x, err = {{$selName}}(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of({{reflectKind .K}}), WithShape(2, 3), ) + if x, err = {{$selName}}(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar({{if eq .K.String "bool" -}}false{{else if eq .K.String "string" -}}""{{else -}}{{asType .K}}(0) {{end -}} )) + if x, err = {{$selName}}(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = {{$selName}}(T, 10); err == nil{ + t.Fatal("Expected errors") + } +} +` + +const nativeSelectStubsRaw = `//go:linkname Select{{short .}} gorgonia.org/tensor.nativeSelect{{short .}} + +// Select{{short .}} creates a slice of {{asType .}}s. See Example of NativeSelectF64. +func Select{{short .}}(t *tensor.Dense, axis int) (retVal [][]{{asType .}}, err error) +` + +var ( + NativeSelect *template.Template + NativeSelectTest *template.Template + NativeSelectStubs *template.Template +) + +func init() { + NativeSelect = template.Must(template.New("NativeSelect").Funcs(funcs).Parse(nativeSelectRaw)) + NativeSelectTest = template.Must(template.New("NativeSelectTest").Funcs(funcs).Parse(nativeSelectTestRaw)) + NativeSelectStubs = template.Must(template.New("NativeSelectStub").Funcs(funcs).Parse(nativeSelectStubsRaw)) +} + +// generateNativeSelect generates code for the native selection. `isNative` indicates if the +// code is meant to be generated for the native package. The code is generated for the native package +// only for the purposes of the `purego` build tag. 
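+// The non-native variant is written into the tensor package itself (as nativeSelect<Kind>), while the
+// native variants are exposed from gorgonia.org/tensor/native as Select<Kind>; see the pipeline calls in main.go.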
+func generateNativeSelect(isNative bool) func(io.Writer, Kinds) { + type IterTup struct { + N bool + K reflect.Kind + } + return func(f io.Writer, ak Kinds) { + if isNative { + fmt.Fprintf(f, importUnqualifiedTensor) + } else { + fmt.Fprintf(f, "%v\n", checkNativeSelectable) + } + ks := filter(ak.Kinds, isSpecialized) + for _, k := range ks { + fmt.Fprintf(f, "/* Native Select for %v */\n\n", k) + NativeSelect.Execute(f, IterTup{N: isNative, K: k}) + fmt.Fprint(f, "\n\n") + } + } +} + +func generateNativeSelectTests(isNative bool) func(f io.Writer, ak Kinds) { + type IterTup struct { + N bool + K reflect.Kind + } + return func(f io.Writer, ak Kinds) { + if isNative { + fmt.Fprintf(f, importUnqualifiedTensor) + } + ks := filter(ak.Kinds, isSpecialized) + for _, k := range ks { + NativeSelectTest.Execute(f, IterTup{N: isNative, K: k}) + fmt.Fprint(f, "\n\n") + } + } +} + +func generateNativeSelectStubs(f io.Writer, ak Kinds) { + fmt.Fprintf(f, importUnsafe) // this is required for go:linkname to work + ks := filter(ak.Kinds, isSpecialized) + for _, k := range ks { + NativeSelectStubs.Execute(f, k) + fmt.Fprintf(f, "\n\n") + } +} + +func generateNativeSelChecks(f io.Writer, ak Kinds) { + // fmt.Fprintf(f, importUnqualifiedTensor) // already generated by generateNativeIterChecks + fmt.Fprintf(f, "%v\n", checkNativeSelectable) +} diff --git a/genlib2/package.go b/genlib2/package.go index e78e7c0..8ffcf79 100644 --- a/genlib2/package.go +++ b/genlib2/package.go @@ -8,12 +8,27 @@ import ( func writePkgName(f io.Writer, pkg string) { switch pkg { case tensorPkgLoc: - fmt.Fprintf(f, "// %s\n\npackage tensor\n\n", genmsg) + fmt.Fprintf(f, "package tensor\n\n // %s\n\n", genmsg) + case nativePkgLoc: + fmt.Fprintf(f, "package native\n\n // %s\n\n", genmsg) + case nativePkgLoc + "_unsafe": + fmt.Fprintf(f, "// +build !purego \n\npackage native\n\n // %s\n\n", genmsg) + case nativePkgLoc + "_purego": + fmt.Fprintf(f, "// +build purego \n\npackage native\n\n // %s\n\n", genmsg) case execLoc: - fmt.Fprintf(f, "// %s\n\npackage execution\n\n", genmsg) + fmt.Fprintf(f, "package execution\n\n // %s\n\n", genmsg) case storageLoc: - fmt.Fprintf(f, "// %s\n\npackage storage\n\n", genmsg) + fmt.Fprintf(f, "package storage\n\n // %s\n\n", genmsg) default: - fmt.Fprintf(f, "// %s\n\npackage unknown\n\n", genmsg) + fmt.Fprintf(f, "package unknown\n\n %s\n\n", genmsg) } } + +const importUnqualifiedTensor = `import . "gorgonia.org/tensor" +` + +const importInternalNative = `import inative "gorgonia.org/tensor/internal/native" +` + +const importUnsafe = `import _ "unsafe" +` diff --git a/genlib2/testutils.go b/genlib2/testutils.go index 177333f..c7dbe81 100644 --- a/genlib2/testutils.go +++ b/genlib2/testutils.go @@ -90,7 +90,7 @@ const qcGenraw = `func randomQC(a Tensor, r *rand.Rand) { s[i] = randomString() {{else if eq .String "unsafe.Pointer" -}} s[i] = nil - {{end -}} + {{end -}} } {{end -}} {{end -}} @@ -99,7 +99,7 @@ const qcGenraw = `func randomQC(a Tensor, r *rand.Rand) { ` const testQCRaw = `type QCDense{{short .}} struct { - *Dense + *Dense } func (*QCDense{{short .}}) Generate(r *rand.Rand, size int) reflect.Value { s := make([]{{asType .}}, size) @@ -137,11 +137,11 @@ const mutateFnsRaw = `func mutate{{short .}}(a {{asType . 
}}){{asType .}} { {{if {{else if eq .String "bool" -}}return true } {{else if eq .String "string" -}}return "Hello World"} {{else if eq .String "uintptr" -}}return 0xdeadbeef} -{{else if eq .String "unsafe.Pointer" -}}return unsafe.Pointer(uintptr(0xdeadbeef))} -{{end -}} +{{else if eq .String "unsafe.Pointer" -}}return unsafe.Pointer(uintptr(0xdeadbeef))} +{{end -}} ` -const identityValsRaw = `func identityVal(x int, dt Dtype) interface{} { +const identityValsRaw = `func identityVal(x int, dt dtype.Dtype) interface{} { switch dt { {{range .Kinds -}} case {{reflectKind .}}: diff --git a/genlib2/unary_tests.go b/genlib2/unary_tests.go index dedd02d..5153f2b 100644 --- a/genlib2/unary_tests.go +++ b/genlib2/unary_tests.go @@ -1,151 +1,151 @@ -package main - -import ( - "fmt" - "io" - "text/template" -) - -const unaryTestBodyRaw = `invFn := func(q *Dense) bool { - a := q.Clone().(*Dense) - {{template "funcoptdecl" -}} - correct := a.Clone().(*Dense) - {{template "funcoptcorrect" -}} - - - we, willFailEq := willerr(a, {{.TypeClassName}}, {{.EqFailTypeClassName}}) - _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok - - ret, err := {{.Name}}(a {{template "funcoptuse"}}) - if err, retEarly := qcErrCheck(t, "{{.Name}}", a, nil, we, err); retEarly{ - if err != nil { - return false - } - return true - } - {{if ne .InvTypeClass "" -}} - if err := typeclassCheck(a.Dtype(), {{.InvTypeClass}}); err != nil { - return true // uninvertible due to type class implementation issues - } - {{end -}} - {{if eq .FuncOpt "incr" -}} - if ret, err = Sub(ret, identityVal(100, a.Dtype()), UseUnsafe()) ; err != nil { - t.Errorf("err while subtracting incr: %v", err) - return false - } - {{end -}} - {{.Inv}}(ret, UseUnsafe()) - if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { - return false - } - {{template "funcoptcheck" -}} - return true -} - -if err := quick.Check(invFn, &quick.Config{Rand:newRand(), MaxCount: quickchecks}); err != nil{ - t.Errorf("Inv tests for {{.Name}} failed: %v", err) -} -` - -type unaryTest struct { - unaryOp - FuncOpt string - EqFailTypeClassName string - InvTypeClass string -} - -func (fn *unaryTest) Name() string { - if fn.unaryOp.Name() == "Eq" || fn.unaryOp.Name() == "Ne" { - return "El" + fn.unaryOp.Name() - } - return fn.unaryOp.Name() -} - -func (fn *unaryTest) Signature() *Signature { - name := fmt.Sprintf("Test%s", fn.unaryOp.Name()) - if fn.FuncOpt != "" { - name += "_" + fn.FuncOpt - } - return &Signature{ - Name: name, - NameTemplate: plainName, - ParamNames: []string{"t"}, - ParamTemplates: []*template.Template{testingType}, - } -} - -func (fn *unaryTest) WriteBody(w io.Writer) { - t := template.Must(template.New("unary test body").Funcs(funcs).Parse(unaryTestBodyRaw)) - template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) - template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt])) - template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) - template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) - t.Execute(w, fn) -} - -func (fn *unaryTest) canWrite() bool { return fn.Inv != "" } - -func (fn *unaryTest) Write(w io.Writer) { - sig := fn.Signature() - w.Write([]byte("func ")) - sig.Write(w) - w.Write([]byte("{\n")) - fn.WriteBody(w) - w.Write([]byte("}\n")) -} - -func generateAPIUnaryTests(f io.Writer, ak Kinds) { - var tests []*unaryTest - for _, op := range conditionalUnaries { - t := &unaryTest{ - unaryOp: op, - EqFailTypeClassName: "nil", - } - - tests = append(tests, t) - } - - for _, op := 
range unconditionalUnaries { - t := &unaryTest{ - unaryOp: op, - EqFailTypeClassName: "nil", - } - switch op.name { - case "Square": - t.InvTypeClass = "floatcmplxTypes" - case "Cube": - t.InvTypeClass = "floatTypes" - } - - tests = append(tests, t) - } - - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - fn.FuncOpt = "unsafe" - } - - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - fn.FuncOpt = "reuse" - } - - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - fn.FuncOpt = "incr" - } - - // for now incr cannot be quickchecked - - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - } -} +package main + +import ( + "fmt" + "io" + "text/template" +) + +const unaryTestBodyRaw = `invFn := func(q *Dense) bool { + a := q.Clone().(*Dense) + {{template "funcoptdecl" -}} + correct := a.Clone().(*Dense) + {{template "funcoptcorrect" -}} + + + we, willFailEq := willerr(a, {{.TypeClassName}}, {{.EqFailTypeClassName}}) + _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok + + ret, err := {{.Name}}(a {{template "funcoptuse"}}) + if err, retEarly := qcErrCheck(t, "{{.Name}}", a, nil, we, err); retEarly{ + if err != nil { + return false + } + return true + } + {{if ne .InvTypeClass "" -}} + if err := dtype.TypeClassCheck(a.Dtype(), {{.InvTypeClass}}); err != nil { + return true // uninvertible due to type class implementation issues + } + {{end -}} + {{if eq .FuncOpt "incr" -}} + if ret, err = Sub(ret, identityVal(100, a.Dtype()), UseUnsafe()) ; err != nil { + t.Errorf("err while subtracting incr: %v", err) + return false + } + {{end -}} + {{.Inv}}(ret, UseUnsafe()) + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + {{template "funcoptcheck" -}} + return true +} + +if err := quick.Check(invFn, &quick.Config{Rand:newRand(), MaxCount: quickchecks}); err != nil{ + t.Errorf("Inv tests for {{.Name}} failed: %v", err) +} +` + +type unaryTest struct { + unaryOp + FuncOpt string + EqFailTypeClassName string + InvTypeClass string +} + +func (fn *unaryTest) Name() string { + if fn.unaryOp.Name() == "Eq" || fn.unaryOp.Name() == "Ne" { + return "El" + fn.unaryOp.Name() + } + return fn.unaryOp.Name() +} + +func (fn *unaryTest) Signature() *Signature { + name := fmt.Sprintf("Test%s", fn.unaryOp.Name()) + if fn.FuncOpt != "" { + name += "_" + fn.FuncOpt + } + return &Signature{ + Name: name, + NameTemplate: plainName, + ParamNames: []string{"t"}, + ParamTemplates: []*template.Template{testingType}, + } +} + +func (fn *unaryTest) WriteBody(w io.Writer) { + t := template.Must(template.New("unary test body").Funcs(funcs).Parse(unaryTestBodyRaw)) + template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) + template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt])) + template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) + template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) + t.Execute(w, fn) +} + +func (fn *unaryTest) canWrite() bool { return fn.Inv != "" } + +func (fn *unaryTest) Write(w io.Writer) { + sig := fn.Signature() + w.Write([]byte("func ")) + sig.Write(w) + w.Write([]byte("{\n")) + fn.WriteBody(w) + w.Write([]byte("}\n")) +} + +func generateAPIUnaryTests(f io.Writer, ak Kinds) { + var tests []*unaryTest + for _, op := range conditionalUnaries { + t := &unaryTest{ + unaryOp: op, + EqFailTypeClassName: "nilTC", + } + + tests = append(tests, t) + } + + for _, op := range unconditionalUnaries { + t := &unaryTest{ + unaryOp: op, + 
EqFailTypeClassName: "nilTC", + } + switch op.name { + case "Square": + t.InvTypeClass = "dtype.FloatComplex" + case "Cube": + t.InvTypeClass = "dtype.Floats" + } + + tests = append(tests, t) + } + + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + fn.FuncOpt = "unsafe" + } + + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + fn.FuncOpt = "reuse" + } + + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + fn.FuncOpt = "incr" + } + + // for now incr cannot be quickchecked + + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..bcb9359 --- /dev/null +++ b/go.mod @@ -0,0 +1,34 @@ +module gorgonia.org/tensor + +go 1.18 + +replace gorgonia.org/dtype => /home/chewxy/workspace/gorgoniaws/src/gorgonia.org/dtype + +replace gorgonia.org/shapes => /home/chewxy/workspace/gorgoniaws/src/gorgonia.org/shapes + +require ( + github.com/apache/arrow/go/arrow v0.0.0-20201229220542-30ce2eb5d4dc + github.com/chewxy/hm v1.0.0 // indirect + github.com/chewxy/math32 v1.0.8 + github.com/gogo/protobuf v1.3.2 + github.com/golang/protobuf v1.4.3 + github.com/google/flatbuffers v1.12.0 + github.com/pkg/errors v0.9.1 + github.com/stretchr/testify v1.7.0 + go4.org/unsafe/assume-no-moving-gc v0.0.0-20220617031537-928513b29760 + gonum.org/v1/gonum v0.8.2 + gorgonia.org/dtype v0.0.0-00010101000000-000000000000 + gorgonia.org/shapes v0.0.0-00010101000000-000000000000 + gorgonia.org/vecf32 v0.9.0 + gorgonia.org/vecf64 v0.9.0 +) + +require ( + github.com/davecgh/go-spew v1.1.0 // indirect + github.com/google/gofuzz v1.2.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/xtgo/set v1.0.0 // indirect + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + google.golang.org/protobuf v1.25.0 // indirect + gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..60fa6ad --- /dev/null +++ b/go.sum @@ -0,0 +1,164 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= +github.com/apache/arrow/go/arrow v0.0.0-20201229220542-30ce2eb5d4dc h1:zvQ6w7KwtQWgMQiewOF9tFtundRMVZFSAksNV6ogzuY= +github.com/apache/arrow/go/arrow v0.0.0-20201229220542-30ce2eb5d4dc/go.mod h1:c9sxoIT3YgLxH4UhLOCKaBlEojuMhVYpk4Ntv3opUTQ= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k= +github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0= +github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0= +github.com/chewxy/math32 v1.0.8 h1:fU5E4Ec4Z+5RtRAi3TovSxUjQPkgRh+HbP7tKB2OFbM= +github.com/chewxy/math32 v1.0.8/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUwkSsLqs= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/envoyproxy/go-control-plane 
v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= +github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= +github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/google/flatbuffers v1.11.0/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= +github.com/google/flatbuffers v1.12.0 h1:/PtAHvnBY4Kqnx/xCQ3OIV9uYcSFGScBsWI3Oogeh6w= +github.com/google/flatbuffers v1.12.0/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.0 h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= +github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= +github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod 
h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.2.0/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY= +github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8= +github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +go4.org/unsafe/assume-no-moving-gc v0.0.0-20220617031537-928513b29760 h1:FyBZqvoA/jbNzuAWLQE2kG820zMAkcilx6BMjGbL/E4= +go4.org/unsafe/assume-no-moving-gc v0.0.0-20220617031537-928513b29760/go.mod h1:FftLjUGFEDu5k8lt0ddY+HcrH/qU/0qk+H8j9/nTl3E= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2 h1:y102fOLFqhV41b+4GPiJoa0k/x+pJcEi2/HB1Y5T6fU= +golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod 
h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200904194848-62affa334b73/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20201021035429-f5854403a974 h1:IX6qOQeG5uLjB/hjjwjedwfjND0hgjPMMyO1RoIXQNI= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200909081042-eff7692f9009/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f h1:+Nyd8tzPX9R7BWHguqsrbFdRx3WQ/1ib8I44HXV5yTA= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= 
+golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= +gonum.org/v1/gonum v0.8.2 h1:CCXrcPKiGGotvnN6jfUsKk4rRqm7q09/YbKb5xCEvtM= +gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= +gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc= +gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= +gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= +google.golang.org/genproto v0.0.0-20200911024640-645f7a48b24f h1:Yv4xsIx7HZOoyUGSJ2ksDyWE2qIBXROsZKt2ny3hCGM= +google.golang.org/genproto v0.0.0-20200911024640-645f7a48b24f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.32.0 h1:zWTV+LMdc3kaiJMSTOFz2UgSBgx8RNQoTGiZu3fR9S0= +google.golang.org/grpc v1.32.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= +google.golang.org/grpc/cmd/protoc-gen-go-grpc v0.0.0-20200910201057-6591123024b3/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= +google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= +google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c 
h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorgonia.org/vecf32 v0.9.0 h1:PClazic1r+JVJ1dEzRXgeiVl4g1/Hf/w+wUSqnco1Xg= +gorgonia.org/vecf32 v0.9.0/go.mod h1:NCc+5D2oxddRL11hd+pCB1PEyXWOyiQxfZ/1wwhOXCA= +gorgonia.org/vecf64 v0.9.0 h1:bgZDP5x0OzBF64PjMGC3EvTdOoMEcmfAh1VCUnZFm1A= +gorgonia.org/vecf64 v0.9.0/go.mod h1:hp7IOWCnRiVQKON73kkC/AUMtEXyf9kGlVrtPQ9ccVA= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/interfaces.go b/interfaces.go index 40be33d..7061997 100644 --- a/interfaces.go +++ b/interfaces.go @@ -3,12 +3,13 @@ package tensor import ( "reflect" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) // Dtyper is any type that has a Dtype type Dtyper interface { - Dtype() Dtype + Dtype() dtype.Dtype } // Eq is any type where you can perform an equality test @@ -71,12 +72,27 @@ type Slicer interface { Slice(...Slice) (View, error) } +// SlicerInto is any tensor that can slice into another tensor. +// The other tensor may already have data allocated in it. +// If that is the case then the slice will be a copy operation. +type SlicerInto interface { + SliceInto(view Tensor, slices ...Slice) (retVal Tensor, err error) +} + +// Reslicer is any tensor that can reslice. +// To reslice is to reuse the container (*Dense, *CS) etc, but with new `Slice`s applied to it. +// +// e.g: A is a (3,3) matrix that has been sliced at [1:3, 1:3]. Call it B. So now B's shape is (2,2). +// B.Reslice(S(0,2), S(0,2)) would reslice the original tensor (A) with the new slices. +type Reslicer interface { + Reslice(...Slice) (View, error) +} + // DenseTensor is the interface for any Dense tensor. type DenseTensor interface { Tensor Info() *AP - DataOrder() DataOrder IsMatrix() bool IsVector() bool IsRowVec() bool @@ -89,6 +105,7 @@ type DenseTensor interface { rtype() reflect.Type reshape(dims ...int) error + setDataOrder(o DataOrder) isTransposed() bool ostrides() []int oshape() Shape @@ -131,6 +148,11 @@ type Kinder interface { Kind() reflect.Kind } +// MakeAliker is any Tensor that can make more like itself. 
+type MakeAliker interface { + MakeAike(opts ...ConsOpt) Tensor +} + type headerer interface { hdr() *storage.Header } @@ -144,6 +166,17 @@ type unsafeMem interface { Set(i int, x interface{}) GetF64(i int) float64 GetF32(i int) float32 + Ints() []int Float64s() []float64 Float32s() []float32 + Complex64s() []complex64 + Complex128s() []complex128 +} + +type float64ser interface { + Float64s() []float64 +} + +type float32ser interface { + Float32s() []float32 } diff --git a/internal/IDLs/generated.fbs b/internal/IDLs/generated.fbs new file mode 100644 index 0000000..47ffce2 --- /dev/null +++ b/internal/IDLs/generated.fbs @@ -0,0 +1,38 @@ +// Generated from generated.proto + +namespace gorgonia.org.tensor.internal.serialization.pb; + +enum Triangle : int { + NOT_TRIANGLE = 0, + UPPER = 1, + LOWER = 2, + SYMMETRIC = 3, +} + +table AP { + shape:[int]; + strides:[int]; + o:uint; + t:gorgonia.org.tensor.internal.serialization.pb.Triangle; +} + +table Dense { + shape:[int]; + strides:[int]; + o:uint; + t:gorgonia.org.tensor.internal.serialization.pb.Triangle; + type:string; + data:[ubyte]; +} + +table MaskedDense { + shape:[int]; + strides:[int]; + o:uint; + t:gorgonia.org.tensor.internal.serialization.pb.Triangle; + type:string; + data:[ubyte]; + mask:[bool]; + mask_is_soft:[bool]; +} + diff --git a/internal/IDLs/generated.proto b/internal/IDLs/generated.proto new file mode 100755 index 0000000..c737106 --- /dev/null +++ b/internal/IDLs/generated.proto @@ -0,0 +1,52 @@ +syntax = "proto3"; +package gorgonia.org.tensor.internal.serialization.pb; + +import "github.com/gogo/protobuf/gogoproto/gogo.proto"; + +option (gogoproto.protosizer_all) = true; +option (gogoproto.sizer_all) = false; +option go_package = "pb"; + +message AP { + option (gogoproto.goproto_getters) = false; + option (gogoproto.typedecl) = false; + repeated int32 shape = 1; + repeated int32 strides = 2; + uint32 o = 3 [(gogoproto.casttype) = "DataOrder"]; + gorgonia.org.tensor.internal.serialization.pb.Triangle t = 4; +} + +message Dense { + option (gogoproto.goproto_getters) = false; + option (gogoproto.typedecl) = false; + repeated int32 shape = 1; + repeated int32 strides = 2; + uint32 o = 3 [(gogoproto.casttype) = "DataOrder"]; + gorgonia.org.tensor.internal.serialization.pb.Triangle t = 4; + string type = 5; + bytes data = 6; +} + +message MaskedDense { + option (gogoproto.goproto_getters) = false; + option (gogoproto.typedecl) = false; + repeated int32 shape = 1; + repeated int32 strides = 2; + uint32 o = 3 [(gogoproto.casttype) = "DataOrder"]; + gorgonia.org.tensor.internal.serialization.pb.Triangle t = 4; + string type = 5; + bytes data = 6; + repeated bool mask = 7; + repeated bool mask_is_soft = 8; +} + +enum Triangle { + option (gogoproto.enumdecl) = false; + option (gogoproto.goproto_enum_prefix) = false; + option (gogoproto.goproto_enum_stringer) = false; + NOT_TRIANGLE = 0 [(gogoproto.enumvalue_customname) = "NotTriangle"]; + UPPER = 1 [(gogoproto.enumvalue_customname) = "Upper"]; + LOWER = 2 [(gogoproto.enumvalue_customname) = "Lower"]; + SYMMETRIC = 3 [(gogoproto.enumvalue_customname) = "Symmetric"]; +} + diff --git a/internal/execution/e.go b/internal/execution/e.go index 670ae0b..83fcc1f 100644 --- a/internal/execution/e.go +++ b/internal/execution/e.go @@ -38,7 +38,7 @@ var ( UnsafePointer = reflect.TypeOf(unsafe.Pointer(&Uintptr)) ) -func isScalar(a *storage.Header) bool { return a.L == 1 } +func isScalar(a *storage.Header, t reflect.Type) bool { return a.TypedLen(t) == 1 } type errorIndices []int diff --git 
a/internal/execution/eng_argmethods.go b/internal/execution/eng_argmethods.go index 05ed725..9adc173 100644 --- a/internal/execution/eng_argmethods.go +++ b/internal/execution/eng_argmethods.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -9,6 +7,8 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + func (e E) ArgmaxIter(t reflect.Type, a *storage.Header, it Iterator, lastSize int) (indices []int, err error) { var next int switch t { diff --git a/internal/execution/eng_arith.go b/internal/execution/eng_arith.go index f626a3d..9b193ba 100644 --- a/internal/execution/eng_arith.go +++ b/internal/execution/eng_arith.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -9,9 +7,11 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + func (e E) Add(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -230,8 +230,8 @@ func (e E) Add(t reflect.Type, a *storage.Header, b *storage.Header) (err error) } func (e E) Sub(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -436,8 +436,8 @@ func (e E) Sub(t reflect.Type, a *storage.Header, b *storage.Header) (err error) } func (e E) Mul(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -642,8 +642,8 @@ func (e E) Mul(t reflect.Type, a *storage.Header, b *storage.Header) (err error) } func (e E) Div(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -848,8 +848,8 @@ func (e E) Div(t reflect.Type, a *storage.Header, b *storage.Header) (err error) } func (e E) Pow(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Float32: @@ -914,8 +914,8 @@ func (e E) Pow(t reflect.Type, a *storage.Header, b *storage.Header) (err error) } func (e E) Mod(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -1092,11 +1092,11 @@ func (e E) Mod(t reflect.Type, a *storage.Header, b *storage.Header) (err error) } func (e E) AddIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + as := isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on scalar increment. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on scalar increment. 
a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -1406,11 +1406,11 @@ func (e E) AddIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *s } func (e E) SubIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + as := isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on scalar increment. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on scalar increment. a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -1700,11 +1700,11 @@ func (e E) SubIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *s } func (e E) MulIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + as := isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on scalar increment. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on scalar increment. a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -1994,11 +1994,11 @@ func (e E) MulIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *s } func (e E) DivIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + as := isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on scalar increment. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on scalar increment. a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -2288,11 +2288,11 @@ func (e E) DivIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *s } func (e E) PowIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + as := isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on scalar increment. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on scalar increment. a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -2382,11 +2382,11 @@ func (e E) PowIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *s } func (e E) ModIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + as := isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on scalar increment. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on scalar increment. 
a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -2636,8 +2636,8 @@ func (e E) ModIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *s } func (e E) AddIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -2851,13 +2851,13 @@ func (e E) AddIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Ite } return default: - return errors.Errorf("Unsupported type %v for Add", t) + return errors.Errorf("Unsupported type %v for AddIter", t) } } func (e E) SubIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -3057,13 +3057,13 @@ func (e E) SubIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Ite } return default: - return errors.Errorf("Unsupported type %v for Sub", t) + return errors.Errorf("Unsupported type %v for SubIter", t) } } func (e E) MulIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -3263,13 +3263,13 @@ func (e E) MulIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Ite } return default: - return errors.Errorf("Unsupported type %v for Mul", t) + return errors.Errorf("Unsupported type %v for MulIter", t) } } func (e E) DivIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -3469,13 +3469,13 @@ func (e E) DivIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Ite } return default: - return errors.Errorf("Unsupported type %v for Div", t) + return errors.Errorf("Unsupported type %v for DivIter", t) } } func (e E) PowIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Float32: @@ -3535,13 +3535,13 @@ func (e E) PowIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Ite } return default: - return errors.Errorf("Unsupported type %v for Pow", t) + return errors.Errorf("Unsupported type %v for PowIter", t) } } func (e E) ModIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -3713,17 +3713,17 @@ func (e E) ModIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Ite } return default: - return errors.Errorf("Unsupported type %v for Mod", t) + return errors.Errorf("Unsupported type %v for ModIter", t) } } func (e E) AddIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *storage.Header, ait Iterator, bit Iterator, iit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + as := isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on a scalar increment. 
len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -4013,17 +4013,17 @@ func (e E) AddIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, inc return AddIterIncrStr(at, bt, it, ait, bit, iit) } default: - return errors.Errorf("Unsupported type %v for Add", t) + return errors.Errorf("Unsupported type %v for AddIterIncr", t) } } func (e E) SubIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *storage.Header, ait Iterator, bit Iterator, iit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + as := isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -4294,17 +4294,17 @@ func (e E) SubIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, inc return SubIterIncrC128(at, bt, it, ait, bit, iit) } default: - return errors.Errorf("Unsupported type %v for Sub", t) + return errors.Errorf("Unsupported type %v for SubIterIncr", t) } } func (e E) MulIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *storage.Header, ait Iterator, bit Iterator, iit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + as := isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -4575,17 +4575,17 @@ func (e E) MulIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, inc return MulIterIncrC128(at, bt, it, ait, bit, iit) } default: - return errors.Errorf("Unsupported type %v for Mul", t) + return errors.Errorf("Unsupported type %v for MulIterIncr", t) } } func (e E) DivIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *storage.Header, ait Iterator, bit Iterator, iit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + as := isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -4856,17 +4856,17 @@ func (e E) DivIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, inc return DivIterIncrC128(at, bt, it, ait, bit, iit) } default: - return errors.Errorf("Unsupported type %v for Div", t) + return errors.Errorf("Unsupported type %v for DivIterIncr", t) } } func (e E) PowIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *storage.Header, ait Iterator, bit Iterator, iit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + as := isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on a scalar increment. 
len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -4947,17 +4947,17 @@ func (e E) PowIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, inc return PowIterIncrC128(at, bt, it, ait, bit, iit) } default: - return errors.Errorf("Unsupported type %v for Pow", t) + return errors.Errorf("Unsupported type %v for PowIterIncr", t) } } func (e E) ModIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *storage.Header, ait Iterator, bit Iterator, iit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + as := isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -5190,6 +5190,534 @@ func (e E) ModIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, inc return ModIterIncrF64(at, bt, it, ait, bit, iit) } default: - return errors.Errorf("Unsupported type %v for Mod", t) + return errors.Errorf("Unsupported type %v for ModIterIncr", t) + } +} + +func (e E) AddRecv(t reflect.Type, a *storage.Header, b *storage.Header, recv *storage.Header) (err error) { + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(recv, t) + + if ((as && !bs) || (bs && !as)) && rs { + return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) + } + + switch t { + case Int: + at := a.Ints() + bt := b.Ints() + rt := recv.Ints() + AddRecvI(at, bt, rt) + return + case Int8: + at := a.Int8s() + bt := b.Int8s() + rt := recv.Int8s() + AddRecvI8(at, bt, rt) + return + case Int16: + at := a.Int16s() + bt := b.Int16s() + rt := recv.Int16s() + AddRecvI16(at, bt, rt) + return + case Int32: + at := a.Int32s() + bt := b.Int32s() + rt := recv.Int32s() + AddRecvI32(at, bt, rt) + return + case Int64: + at := a.Int64s() + bt := b.Int64s() + rt := recv.Int64s() + AddRecvI64(at, bt, rt) + return + case Uint: + at := a.Uints() + bt := b.Uints() + rt := recv.Uints() + AddRecvU(at, bt, rt) + return + case Uint8: + at := a.Uint8s() + bt := b.Uint8s() + rt := recv.Uint8s() + AddRecvU8(at, bt, rt) + return + case Uint16: + at := a.Uint16s() + bt := b.Uint16s() + rt := recv.Uint16s() + AddRecvU16(at, bt, rt) + return + case Uint32: + at := a.Uint32s() + bt := b.Uint32s() + rt := recv.Uint32s() + AddRecvU32(at, bt, rt) + return + case Uint64: + at := a.Uint64s() + bt := b.Uint64s() + rt := recv.Uint64s() + AddRecvU64(at, bt, rt) + return + case Float32: + at := a.Float32s() + bt := b.Float32s() + rt := recv.Float32s() + AddRecvF32(at, bt, rt) + return + case Float64: + at := a.Float64s() + bt := b.Float64s() + rt := recv.Float64s() + AddRecvF64(at, bt, rt) + return + case Complex64: + at := a.Complex64s() + bt := b.Complex64s() + rt := recv.Complex64s() + AddRecvC64(at, bt, rt) + return + case Complex128: + at := a.Complex128s() + bt := b.Complex128s() + rt := recv.Complex128s() + AddRecvC128(at, bt, rt) + return + case String: + at := a.Strings() + bt := b.Strings() + rt := recv.Strings() + AddRecvStr(at, bt, rt) + return + default: + return errors.Errorf("Unsupported type %v for AddRecv", t) + } +} + +func (e E) SubRecv(t reflect.Type, a *storage.Header, b *storage.Header, recv *storage.Header) (err error) { + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(recv, t) + + if ((as && !bs) || (bs && !as)) && 
rs { + return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) + } + + switch t { + case Int: + at := a.Ints() + bt := b.Ints() + rt := recv.Ints() + SubRecvI(at, bt, rt) + return + case Int8: + at := a.Int8s() + bt := b.Int8s() + rt := recv.Int8s() + SubRecvI8(at, bt, rt) + return + case Int16: + at := a.Int16s() + bt := b.Int16s() + rt := recv.Int16s() + SubRecvI16(at, bt, rt) + return + case Int32: + at := a.Int32s() + bt := b.Int32s() + rt := recv.Int32s() + SubRecvI32(at, bt, rt) + return + case Int64: + at := a.Int64s() + bt := b.Int64s() + rt := recv.Int64s() + SubRecvI64(at, bt, rt) + return + case Uint: + at := a.Uints() + bt := b.Uints() + rt := recv.Uints() + SubRecvU(at, bt, rt) + return + case Uint8: + at := a.Uint8s() + bt := b.Uint8s() + rt := recv.Uint8s() + SubRecvU8(at, bt, rt) + return + case Uint16: + at := a.Uint16s() + bt := b.Uint16s() + rt := recv.Uint16s() + SubRecvU16(at, bt, rt) + return + case Uint32: + at := a.Uint32s() + bt := b.Uint32s() + rt := recv.Uint32s() + SubRecvU32(at, bt, rt) + return + case Uint64: + at := a.Uint64s() + bt := b.Uint64s() + rt := recv.Uint64s() + SubRecvU64(at, bt, rt) + return + case Float32: + at := a.Float32s() + bt := b.Float32s() + rt := recv.Float32s() + SubRecvF32(at, bt, rt) + return + case Float64: + at := a.Float64s() + bt := b.Float64s() + rt := recv.Float64s() + SubRecvF64(at, bt, rt) + return + case Complex64: + at := a.Complex64s() + bt := b.Complex64s() + rt := recv.Complex64s() + SubRecvC64(at, bt, rt) + return + case Complex128: + at := a.Complex128s() + bt := b.Complex128s() + rt := recv.Complex128s() + SubRecvC128(at, bt, rt) + return + default: + return errors.Errorf("Unsupported type %v for SubRecv", t) + } +} + +func (e E) MulRecv(t reflect.Type, a *storage.Header, b *storage.Header, recv *storage.Header) (err error) { + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(recv, t) + + if ((as && !bs) || (bs && !as)) && rs { + return errors.Errorf("Cannot increment on a scalar increment. 
len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) + } + + switch t { + case Int: + at := a.Ints() + bt := b.Ints() + rt := recv.Ints() + MulRecvI(at, bt, rt) + return + case Int8: + at := a.Int8s() + bt := b.Int8s() + rt := recv.Int8s() + MulRecvI8(at, bt, rt) + return + case Int16: + at := a.Int16s() + bt := b.Int16s() + rt := recv.Int16s() + MulRecvI16(at, bt, rt) + return + case Int32: + at := a.Int32s() + bt := b.Int32s() + rt := recv.Int32s() + MulRecvI32(at, bt, rt) + return + case Int64: + at := a.Int64s() + bt := b.Int64s() + rt := recv.Int64s() + MulRecvI64(at, bt, rt) + return + case Uint: + at := a.Uints() + bt := b.Uints() + rt := recv.Uints() + MulRecvU(at, bt, rt) + return + case Uint8: + at := a.Uint8s() + bt := b.Uint8s() + rt := recv.Uint8s() + MulRecvU8(at, bt, rt) + return + case Uint16: + at := a.Uint16s() + bt := b.Uint16s() + rt := recv.Uint16s() + MulRecvU16(at, bt, rt) + return + case Uint32: + at := a.Uint32s() + bt := b.Uint32s() + rt := recv.Uint32s() + MulRecvU32(at, bt, rt) + return + case Uint64: + at := a.Uint64s() + bt := b.Uint64s() + rt := recv.Uint64s() + MulRecvU64(at, bt, rt) + return + case Float32: + at := a.Float32s() + bt := b.Float32s() + rt := recv.Float32s() + MulRecvF32(at, bt, rt) + return + case Float64: + at := a.Float64s() + bt := b.Float64s() + rt := recv.Float64s() + MulRecvF64(at, bt, rt) + return + case Complex64: + at := a.Complex64s() + bt := b.Complex64s() + rt := recv.Complex64s() + MulRecvC64(at, bt, rt) + return + case Complex128: + at := a.Complex128s() + bt := b.Complex128s() + rt := recv.Complex128s() + MulRecvC128(at, bt, rt) + return + default: + return errors.Errorf("Unsupported type %v for MulRecv", t) + } +} + +func (e E) DivRecv(t reflect.Type, a *storage.Header, b *storage.Header, recv *storage.Header) (err error) { + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(recv, t) + + if ((as && !bs) || (bs && !as)) && rs { + return errors.Errorf("Cannot increment on a scalar increment. 
len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) + } + + switch t { + case Int: + at := a.Ints() + bt := b.Ints() + rt := recv.Ints() + DivRecvI(at, bt, rt) + return + case Int8: + at := a.Int8s() + bt := b.Int8s() + rt := recv.Int8s() + DivRecvI8(at, bt, rt) + return + case Int16: + at := a.Int16s() + bt := b.Int16s() + rt := recv.Int16s() + DivRecvI16(at, bt, rt) + return + case Int32: + at := a.Int32s() + bt := b.Int32s() + rt := recv.Int32s() + DivRecvI32(at, bt, rt) + return + case Int64: + at := a.Int64s() + bt := b.Int64s() + rt := recv.Int64s() + DivRecvI64(at, bt, rt) + return + case Uint: + at := a.Uints() + bt := b.Uints() + rt := recv.Uints() + DivRecvU(at, bt, rt) + return + case Uint8: + at := a.Uint8s() + bt := b.Uint8s() + rt := recv.Uint8s() + DivRecvU8(at, bt, rt) + return + case Uint16: + at := a.Uint16s() + bt := b.Uint16s() + rt := recv.Uint16s() + DivRecvU16(at, bt, rt) + return + case Uint32: + at := a.Uint32s() + bt := b.Uint32s() + rt := recv.Uint32s() + DivRecvU32(at, bt, rt) + return + case Uint64: + at := a.Uint64s() + bt := b.Uint64s() + rt := recv.Uint64s() + DivRecvU64(at, bt, rt) + return + case Float32: + at := a.Float32s() + bt := b.Float32s() + rt := recv.Float32s() + DivRecvF32(at, bt, rt) + return + case Float64: + at := a.Float64s() + bt := b.Float64s() + rt := recv.Float64s() + DivRecvF64(at, bt, rt) + return + case Complex64: + at := a.Complex64s() + bt := b.Complex64s() + rt := recv.Complex64s() + DivRecvC64(at, bt, rt) + return + case Complex128: + at := a.Complex128s() + bt := b.Complex128s() + rt := recv.Complex128s() + DivRecvC128(at, bt, rt) + return + default: + return errors.Errorf("Unsupported type %v for DivRecv", t) + } +} + +func (e E) PowRecv(t reflect.Type, a *storage.Header, b *storage.Header, recv *storage.Header) (err error) { + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(recv, t) + + if ((as && !bs) || (bs && !as)) && rs { + return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) + } + + switch t { + case Float32: + at := a.Float32s() + bt := b.Float32s() + rt := recv.Float32s() + PowRecvF32(at, bt, rt) + return + case Float64: + at := a.Float64s() + bt := b.Float64s() + rt := recv.Float64s() + PowRecvF64(at, bt, rt) + return + case Complex64: + at := a.Complex64s() + bt := b.Complex64s() + rt := recv.Complex64s() + PowRecvC64(at, bt, rt) + return + case Complex128: + at := a.Complex128s() + bt := b.Complex128s() + rt := recv.Complex128s() + PowRecvC128(at, bt, rt) + return + default: + return errors.Errorf("Unsupported type %v for PowRecv", t) + } +} + +func (e E) ModRecv(t reflect.Type, a *storage.Header, b *storage.Header, recv *storage.Header) (err error) { + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(recv, t) + + if ((as && !bs) || (bs && !as)) && rs { + return errors.Errorf("Cannot increment on a scalar increment. 
len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) + } + + switch t { + case Int: + at := a.Ints() + bt := b.Ints() + rt := recv.Ints() + ModRecvI(at, bt, rt) + return + case Int8: + at := a.Int8s() + bt := b.Int8s() + rt := recv.Int8s() + ModRecvI8(at, bt, rt) + return + case Int16: + at := a.Int16s() + bt := b.Int16s() + rt := recv.Int16s() + ModRecvI16(at, bt, rt) + return + case Int32: + at := a.Int32s() + bt := b.Int32s() + rt := recv.Int32s() + ModRecvI32(at, bt, rt) + return + case Int64: + at := a.Int64s() + bt := b.Int64s() + rt := recv.Int64s() + ModRecvI64(at, bt, rt) + return + case Uint: + at := a.Uints() + bt := b.Uints() + rt := recv.Uints() + ModRecvU(at, bt, rt) + return + case Uint8: + at := a.Uint8s() + bt := b.Uint8s() + rt := recv.Uint8s() + ModRecvU8(at, bt, rt) + return + case Uint16: + at := a.Uint16s() + bt := b.Uint16s() + rt := recv.Uint16s() + ModRecvU16(at, bt, rt) + return + case Uint32: + at := a.Uint32s() + bt := b.Uint32s() + rt := recv.Uint32s() + ModRecvU32(at, bt, rt) + return + case Uint64: + at := a.Uint64s() + bt := b.Uint64s() + rt := recv.Uint64s() + ModRecvU64(at, bt, rt) + return + case Float32: + at := a.Float32s() + bt := b.Float32s() + rt := recv.Float32s() + ModRecvF32(at, bt, rt) + return + case Float64: + at := a.Float64s() + bt := b.Float64s() + rt := recv.Float64s() + ModRecvF64(at, bt, rt) + return + default: + return errors.Errorf("Unsupported type %v for ModRecv", t) } } diff --git a/internal/execution/eng_arith_manual.go b/internal/execution/eng_arith_manual.go new file mode 100644 index 0000000..3a620e6 --- /dev/null +++ b/internal/execution/eng_arith_manual.go @@ -0,0 +1,241 @@ +package execution + +import ( + "reflect" + + "github.com/pkg/errors" + "gorgonia.org/tensor/internal/storage" +) + +func (e E) AddSliced(t reflect.Type, dataA *storage.Header, dstStart, dstEnd int, dataB *storage.Header, srcStart, srcEnd int) (err error) { + ds := dstStart * int(t.Size()) + de := dstEnd * int(t.Size()) + a := &storage.Header{ + Raw: dataA.Raw[ds:de], + } + + ss := srcStart * int(t.Size()) + se := srcEnd * int(t.Size()) + b := &storage.Header{ + Raw: dataB.Raw[ss:se], + } + + as := isScalar(a, t) + bs := isScalar(b, t) + + switch t { + case Int: + at := a.Ints() + bt := b.Ints() + + switch { + case as && bs: + VecAddI(at, bt) + case as && !bs: + AddSVI(at[0], bt) + case !as && bs: + AddVSI(at, bt[0]) + default: + VecAddI(at, bt) + } + return + case Int8: + at := a.Int8s() + bt := b.Int8s() + switch { + case as && bs: + VecAddI8(at, bt) + case as && !bs: + AddSVI8(at[0], bt) + case !as && bs: + AddVSI8(at, bt[0]) + default: + VecAddI8(at, bt) + } + return + case Int16: + at := a.Int16s() + bt := b.Int16s() + switch { + case as && bs: + VecAddI16(at, bt) + case as && !bs: + AddSVI16(at[0], bt) + case !as && bs: + AddVSI16(at, bt[0]) + default: + VecAddI16(at, bt) + } + return + case Int32: + at := a.Int32s() + bt := b.Int32s() + switch { + case as && bs: + VecAddI32(at, bt) + case as && !bs: + AddSVI32(at[0], bt) + case !as && bs: + AddVSI32(at, bt[0]) + default: + VecAddI32(at, bt) + } + return + case Int64: + at := a.Int64s() + bt := b.Int64s() + switch { + case as && bs: + VecAddI64(at, bt) + case as && !bs: + AddSVI64(at[0], bt) + case !as && bs: + AddVSI64(at, bt[0]) + default: + VecAddI64(at, bt) + } + return + case Uint: + at := a.Uints() + bt := b.Uints() + switch { + case as && bs: + VecAddU(at, bt) + case as && !bs: + AddSVU(at[0], bt) + case !as && bs: + AddVSU(at, bt[0]) + default: + VecAddU(at, bt) + } + return + case Uint8: + 
at := a.Uint8s() + bt := b.Uint8s() + switch { + case as && bs: + VecAddU8(at, bt) + case as && !bs: + AddSVU8(at[0], bt) + case !as && bs: + AddVSU8(at, bt[0]) + default: + VecAddU8(at, bt) + } + return + case Uint16: + at := a.Uint16s() + bt := b.Uint16s() + switch { + case as && bs: + VecAddU16(at, bt) + case as && !bs: + AddSVU16(at[0], bt) + case !as && bs: + AddVSU16(at, bt[0]) + default: + VecAddU16(at, bt) + } + return + case Uint32: + at := a.Uint32s() + bt := b.Uint32s() + switch { + case as && bs: + VecAddU32(at, bt) + case as && !bs: + AddSVU32(at[0], bt) + case !as && bs: + AddVSU32(at, bt[0]) + default: + VecAddU32(at, bt) + } + return + case Uint64: + at := a.Uint64s() + bt := b.Uint64s() + switch { + case as && bs: + VecAddU64(at, bt) + case as && !bs: + AddSVU64(at[0], bt) + case !as && bs: + AddVSU64(at, bt[0]) + default: + VecAddU64(at, bt) + } + return + case Float32: + at := a.Float32s() + bt := b.Float32s() + switch { + case as && bs: + VecAddF32(at, bt) + case as && !bs: + AddSVF32(at[0], bt) + case !as && bs: + AddVSF32(at, bt[0]) + default: + VecAddF32(at, bt) + } + return + case Float64: + at := a.Float64s() + bt := b.Float64s() + switch { + case as && bs: + VecAddF64(at, bt) + case as && !bs: + AddSVF64(at[0], bt) + case !as && bs: + AddVSF64(at, bt[0]) + default: + VecAddF64(at, bt) + } + return + case Complex64: + at := a.Complex64s() + bt := b.Complex64s() + switch { + case as && bs: + VecAddC64(at, bt) + case as && !bs: + AddSVC64(at[0], bt) + case !as && bs: + AddVSC64(at, bt[0]) + default: + VecAddC64(at, bt) + } + return + case Complex128: + at := a.Complex128s() + bt := b.Complex128s() + switch { + case as && bs: + VecAddC128(at, bt) + case as && !bs: + AddSVC128(at[0], bt) + case !as && bs: + AddVSC128(at, bt[0]) + default: + VecAddC128(at, bt) + } + return + case String: + at := a.Strings() + bt := b.Strings() + switch { + case as && bs: + VecAddStr(at, bt) + case as && !bs: + AddSVStr(at[0], bt) + case !as && bs: + AddVSStr(at, bt[0]) + default: + VecAddStr(at, bt) + } + return + default: + return errors.Errorf("Unsupported type %v for Add", t) + } +} diff --git a/internal/execution/eng_cmp.go b/internal/execution/eng_cmp.go index 9514f61..e5d3dd5 100644 --- a/internal/execution/eng_cmp.go +++ b/internal/execution/eng_cmp.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -9,14 +7,16 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + func (e E) Gt(t reflect.Type, a *storage.Header, b *storage.Header, retVal *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is a scalar. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("retVal is a scalar. a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -221,13 +221,13 @@ func (e E) Gt(t reflect.Type, a *storage.Header, b *storage.Header, retVal *stor } func (e E) Gte(t reflect.Type, a *storage.Header, b *storage.Header, retVal *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is a scalar. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("retVal is a scalar. 
a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -432,13 +432,13 @@ func (e E) Gte(t reflect.Type, a *storage.Header, b *storage.Header, retVal *sto } func (e E) Lt(t reflect.Type, a *storage.Header, b *storage.Header, retVal *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is a scalar. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("retVal is a scalar. a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -643,13 +643,13 @@ func (e E) Lt(t reflect.Type, a *storage.Header, b *storage.Header, retVal *stor } func (e E) Lte(t reflect.Type, a *storage.Header, b *storage.Header, retVal *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is a scalar. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("retVal is a scalar. a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -854,13 +854,13 @@ func (e E) Lte(t reflect.Type, a *storage.Header, b *storage.Header, retVal *sto } func (e E) Eq(t reflect.Type, a *storage.Header, b *storage.Header, retVal *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is a scalar. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("retVal is a scalar. a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -1140,13 +1140,13 @@ func (e E) Eq(t reflect.Type, a *storage.Header, b *storage.Header, retVal *stor } func (e E) Ne(t reflect.Type, a *storage.Header, b *storage.Header, retVal *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is a scalar. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("retVal is a scalar. 
a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -1426,8 +1426,8 @@ func (e E) Ne(t reflect.Type, a *storage.Header, b *storage.Header, retVal *stor } func (e E) GtSame(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -1618,8 +1618,8 @@ func (e E) GtSame(t reflect.Type, a *storage.Header, b *storage.Header) (err err } func (e E) GteSame(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -1810,8 +1810,8 @@ func (e E) GteSame(t reflect.Type, a *storage.Header, b *storage.Header) (err er } func (e E) LtSame(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -2002,8 +2002,8 @@ func (e E) LtSame(t reflect.Type, a *storage.Header, b *storage.Header) (err err } func (e E) LteSame(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -2194,8 +2194,8 @@ func (e E) LteSame(t reflect.Type, a *storage.Header, b *storage.Header) (err er } func (e E) EqSame(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Bool: @@ -2442,8 +2442,8 @@ func (e E) EqSame(t reflect.Type, a *storage.Header, b *storage.Header) (err err } func (e E) NeSame(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Bool: @@ -2690,13 +2690,13 @@ func (e E) NeSame(t reflect.Type, a *storage.Header, b *storage.Header) (err err } func (e E) GtIter(t reflect.Type, a *storage.Header, b *storage.Header, retVal *storage.Header, ait Iterator, bit Iterator, rit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -2888,13 +2888,13 @@ func (e E) GtIter(t reflect.Type, a *storage.Header, b *storage.Header, retVal * } func (e E) GteIter(t reflect.Type, a *storage.Header, b *storage.Header, retVal *storage.Header, ait Iterator, bit Iterator, rit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -3086,13 +3086,13 @@ func (e E) GteIter(t reflect.Type, a *storage.Header, b *storage.Header, retVal } func (e E) LtIter(t reflect.Type, a *storage.Header, b *storage.Header, retVal *storage.Header, ait Iterator, bit Iterator, rit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + as := isScalar(a, t) + bs := isScalar(b, t) + rs 
:= isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -3284,13 +3284,13 @@ func (e E) LtIter(t reflect.Type, a *storage.Header, b *storage.Header, retVal * } func (e E) LteIter(t reflect.Type, a *storage.Header, b *storage.Header, retVal *storage.Header, ait Iterator, bit Iterator, rit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -3482,13 +3482,13 @@ func (e E) LteIter(t reflect.Type, a *storage.Header, b *storage.Header, retVal } func (e E) EqIter(t reflect.Type, a *storage.Header, b *storage.Header, retVal *storage.Header, ait Iterator, bit Iterator, rit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -3750,13 +3750,13 @@ func (e E) EqIter(t reflect.Type, a *storage.Header, b *storage.Header, retVal * } func (e E) NeIter(t reflect.Type, a *storage.Header, b *storage.Header, retVal *storage.Header, ait Iterator, bit Iterator, rit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -4018,8 +4018,8 @@ func (e E) NeIter(t reflect.Type, a *storage.Header, b *storage.Header, retVal * } func (e E) GtSameIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -4210,8 +4210,8 @@ func (e E) GtSameIter(t reflect.Type, a *storage.Header, b *storage.Header, ait } func (e E) GteSameIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -4402,8 +4402,8 @@ func (e E) GteSameIter(t reflect.Type, a *storage.Header, b *storage.Header, ait } func (e E) LtSameIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -4594,8 +4594,8 @@ func (e E) LtSameIter(t reflect.Type, a *storage.Header, b *storage.Header, ait } func (e E) LteSameIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := 
isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -4786,8 +4786,8 @@ func (e E) LteSameIter(t reflect.Type, a *storage.Header, b *storage.Header, ait } func (e E) EqSameIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Bool: @@ -5034,8 +5034,8 @@ func (e E) EqSameIter(t reflect.Type, a *storage.Header, b *storage.Header, ait } func (e E) NeSameIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Bool: diff --git a/internal/execution/eng_map.go b/internal/execution/eng_map.go index 17ca682..ecd2b64 100644 --- a/internal/execution/eng_map.go +++ b/internal/execution/eng_map.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -10,8 +8,10 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + func (e E) Map(t reflect.Type, fn interface{}, a *storage.Header, incr bool) (err error) { - as := isScalar(a) + as := isScalar(a, t) switch t { case Bool: var f0 func(bool) bool diff --git a/internal/execution/eng_minmaxbetween.go b/internal/execution/eng_minmaxbetween.go new file mode 100644 index 0000000..8c41606 --- /dev/null +++ b/internal/execution/eng_minmaxbetween.go @@ -0,0 +1,778 @@ +package execution + +import ( + "reflect" + + "github.com/pkg/errors" + "gorgonia.org/tensor/internal/storage" +) + +// Code generated by genlib2. DO NOT EDIT. + +func (e E) MaxBetween(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { + as := isScalar(a, t) + bs := isScalar(b, t) + + switch t { + case Int: + at := a.Ints() + bt := b.Ints() + switch { + case as && bs: + VecMaxI(at, bt) + case as && !bs: + MaxSVI(at[0], bt) + case !as && bs: + MaxVSI(at, bt[0]) + default: + VecMaxI(at, bt) + } + return + case Int8: + at := a.Int8s() + bt := b.Int8s() + switch { + case as && bs: + VecMaxI8(at, bt) + case as && !bs: + MaxSVI8(at[0], bt) + case !as && bs: + MaxVSI8(at, bt[0]) + default: + VecMaxI8(at, bt) + } + return + case Int16: + at := a.Int16s() + bt := b.Int16s() + switch { + case as && bs: + VecMaxI16(at, bt) + case as && !bs: + MaxSVI16(at[0], bt) + case !as && bs: + MaxVSI16(at, bt[0]) + default: + VecMaxI16(at, bt) + } + return + case Int32: + at := a.Int32s() + bt := b.Int32s() + switch { + case as && bs: + VecMaxI32(at, bt) + case as && !bs: + MaxSVI32(at[0], bt) + case !as && bs: + MaxVSI32(at, bt[0]) + default: + VecMaxI32(at, bt) + } + return + case Int64: + at := a.Int64s() + bt := b.Int64s() + switch { + case as && bs: + VecMaxI64(at, bt) + case as && !bs: + MaxSVI64(at[0], bt) + case !as && bs: + MaxVSI64(at, bt[0]) + default: + VecMaxI64(at, bt) + } + return + case Uint: + at := a.Uints() + bt := b.Uints() + switch { + case as && bs: + VecMaxU(at, bt) + case as && !bs: + MaxSVU(at[0], bt) + case !as && bs: + MaxVSU(at, bt[0]) + default: + VecMaxU(at, bt) + } + return + case Uint8: + at := a.Uint8s() + bt := b.Uint8s() + switch { + case as && bs: + VecMaxU8(at, bt) + case as && !bs: + MaxSVU8(at[0], bt) + case !as && bs: + MaxVSU8(at, bt[0]) + default: + VecMaxU8(at, bt) + } + return + case Uint16: + at := a.Uint16s() + bt := b.Uint16s() + switch { + case as && bs: + VecMaxU16(at, bt) + case as && !bs: + MaxSVU16(at[0], bt) + case !as && bs: + MaxVSU16(at, bt[0]) + default: + VecMaxU16(at, 
bt) + } + return + case Uint32: + at := a.Uint32s() + bt := b.Uint32s() + switch { + case as && bs: + VecMaxU32(at, bt) + case as && !bs: + MaxSVU32(at[0], bt) + case !as && bs: + MaxVSU32(at, bt[0]) + default: + VecMaxU32(at, bt) + } + return + case Uint64: + at := a.Uint64s() + bt := b.Uint64s() + switch { + case as && bs: + VecMaxU64(at, bt) + case as && !bs: + MaxSVU64(at[0], bt) + case !as && bs: + MaxVSU64(at, bt[0]) + default: + VecMaxU64(at, bt) + } + return + case Float32: + at := a.Float32s() + bt := b.Float32s() + switch { + case as && bs: + VecMaxF32(at, bt) + case as && !bs: + MaxSVF32(at[0], bt) + case !as && bs: + MaxVSF32(at, bt[0]) + default: + VecMaxF32(at, bt) + } + return + case Float64: + at := a.Float64s() + bt := b.Float64s() + switch { + case as && bs: + VecMaxF64(at, bt) + case as && !bs: + MaxSVF64(at[0], bt) + case !as && bs: + MaxVSF64(at, bt[0]) + default: + VecMaxF64(at, bt) + } + return + case String: + at := a.Strings() + bt := b.Strings() + switch { + case as && bs: + VecMaxStr(at, bt) + case as && !bs: + MaxSVStr(at[0], bt) + case !as && bs: + MaxVSStr(at, bt[0]) + default: + VecMaxStr(at, bt) + } + return + default: + return errors.Errorf("Unsupported type %v for Max", t) + } +} + +func (e E) MinBetween(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { + as := isScalar(a, t) + bs := isScalar(b, t) + + switch t { + case Int: + at := a.Ints() + bt := b.Ints() + switch { + case as && bs: + VecMinI(at, bt) + case as && !bs: + MinSVI(at[0], bt) + case !as && bs: + MinVSI(at, bt[0]) + default: + VecMinI(at, bt) + } + return + case Int8: + at := a.Int8s() + bt := b.Int8s() + switch { + case as && bs: + VecMinI8(at, bt) + case as && !bs: + MinSVI8(at[0], bt) + case !as && bs: + MinVSI8(at, bt[0]) + default: + VecMinI8(at, bt) + } + return + case Int16: + at := a.Int16s() + bt := b.Int16s() + switch { + case as && bs: + VecMinI16(at, bt) + case as && !bs: + MinSVI16(at[0], bt) + case !as && bs: + MinVSI16(at, bt[0]) + default: + VecMinI16(at, bt) + } + return + case Int32: + at := a.Int32s() + bt := b.Int32s() + switch { + case as && bs: + VecMinI32(at, bt) + case as && !bs: + MinSVI32(at[0], bt) + case !as && bs: + MinVSI32(at, bt[0]) + default: + VecMinI32(at, bt) + } + return + case Int64: + at := a.Int64s() + bt := b.Int64s() + switch { + case as && bs: + VecMinI64(at, bt) + case as && !bs: + MinSVI64(at[0], bt) + case !as && bs: + MinVSI64(at, bt[0]) + default: + VecMinI64(at, bt) + } + return + case Uint: + at := a.Uints() + bt := b.Uints() + switch { + case as && bs: + VecMinU(at, bt) + case as && !bs: + MinSVU(at[0], bt) + case !as && bs: + MinVSU(at, bt[0]) + default: + VecMinU(at, bt) + } + return + case Uint8: + at := a.Uint8s() + bt := b.Uint8s() + switch { + case as && bs: + VecMinU8(at, bt) + case as && !bs: + MinSVU8(at[0], bt) + case !as && bs: + MinVSU8(at, bt[0]) + default: + VecMinU8(at, bt) + } + return + case Uint16: + at := a.Uint16s() + bt := b.Uint16s() + switch { + case as && bs: + VecMinU16(at, bt) + case as && !bs: + MinSVU16(at[0], bt) + case !as && bs: + MinVSU16(at, bt[0]) + default: + VecMinU16(at, bt) + } + return + case Uint32: + at := a.Uint32s() + bt := b.Uint32s() + switch { + case as && bs: + VecMinU32(at, bt) + case as && !bs: + MinSVU32(at[0], bt) + case !as && bs: + MinVSU32(at, bt[0]) + default: + VecMinU32(at, bt) + } + return + case Uint64: + at := a.Uint64s() + bt := b.Uint64s() + switch { + case as && bs: + VecMinU64(at, bt) + case as && !bs: + MinSVU64(at[0], bt) + case !as && bs: + MinVSU64(at, 
bt[0]) + default: + VecMinU64(at, bt) + } + return + case Float32: + at := a.Float32s() + bt := b.Float32s() + switch { + case as && bs: + VecMinF32(at, bt) + case as && !bs: + MinSVF32(at[0], bt) + case !as && bs: + MinVSF32(at, bt[0]) + default: + VecMinF32(at, bt) + } + return + case Float64: + at := a.Float64s() + bt := b.Float64s() + switch { + case as && bs: + VecMinF64(at, bt) + case as && !bs: + MinSVF64(at[0], bt) + case !as && bs: + MinVSF64(at, bt[0]) + default: + VecMinF64(at, bt) + } + return + case String: + at := a.Strings() + bt := b.Strings() + switch { + case as && bs: + VecMinStr(at, bt) + case as && !bs: + MinSVStr(at[0], bt) + case !as && bs: + MinVSStr(at, bt[0]) + default: + VecMinStr(at, bt) + } + return + default: + return errors.Errorf("Unsupported type %v for Min", t) + } +} + +func (e E) MaxBetweenIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { + as := isScalar(a, t) + bs := isScalar(b, t) + + switch t { + case Int: + at := a.Ints() + bt := b.Ints() + switch { + case as && bs: + VecMaxI(at, bt) + case as && !bs: + MaxIterSVI(at[0], bt, bit) + case !as && bs: + MaxIterVSI(at, bt[0], ait) + default: + VecMaxIterI(at, bt, ait, bit) + } + return + case Int8: + at := a.Int8s() + bt := b.Int8s() + switch { + case as && bs: + VecMaxI8(at, bt) + case as && !bs: + MaxIterSVI8(at[0], bt, bit) + case !as && bs: + MaxIterVSI8(at, bt[0], ait) + default: + VecMaxIterI8(at, bt, ait, bit) + } + return + case Int16: + at := a.Int16s() + bt := b.Int16s() + switch { + case as && bs: + VecMaxI16(at, bt) + case as && !bs: + MaxIterSVI16(at[0], bt, bit) + case !as && bs: + MaxIterVSI16(at, bt[0], ait) + default: + VecMaxIterI16(at, bt, ait, bit) + } + return + case Int32: + at := a.Int32s() + bt := b.Int32s() + switch { + case as && bs: + VecMaxI32(at, bt) + case as && !bs: + MaxIterSVI32(at[0], bt, bit) + case !as && bs: + MaxIterVSI32(at, bt[0], ait) + default: + VecMaxIterI32(at, bt, ait, bit) + } + return + case Int64: + at := a.Int64s() + bt := b.Int64s() + switch { + case as && bs: + VecMaxI64(at, bt) + case as && !bs: + MaxIterSVI64(at[0], bt, bit) + case !as && bs: + MaxIterVSI64(at, bt[0], ait) + default: + VecMaxIterI64(at, bt, ait, bit) + } + return + case Uint: + at := a.Uints() + bt := b.Uints() + switch { + case as && bs: + VecMaxU(at, bt) + case as && !bs: + MaxIterSVU(at[0], bt, bit) + case !as && bs: + MaxIterVSU(at, bt[0], ait) + default: + VecMaxIterU(at, bt, ait, bit) + } + return + case Uint8: + at := a.Uint8s() + bt := b.Uint8s() + switch { + case as && bs: + VecMaxU8(at, bt) + case as && !bs: + MaxIterSVU8(at[0], bt, bit) + case !as && bs: + MaxIterVSU8(at, bt[0], ait) + default: + VecMaxIterU8(at, bt, ait, bit) + } + return + case Uint16: + at := a.Uint16s() + bt := b.Uint16s() + switch { + case as && bs: + VecMaxU16(at, bt) + case as && !bs: + MaxIterSVU16(at[0], bt, bit) + case !as && bs: + MaxIterVSU16(at, bt[0], ait) + default: + VecMaxIterU16(at, bt, ait, bit) + } + return + case Uint32: + at := a.Uint32s() + bt := b.Uint32s() + switch { + case as && bs: + VecMaxU32(at, bt) + case as && !bs: + MaxIterSVU32(at[0], bt, bit) + case !as && bs: + MaxIterVSU32(at, bt[0], ait) + default: + VecMaxIterU32(at, bt, ait, bit) + } + return + case Uint64: + at := a.Uint64s() + bt := b.Uint64s() + switch { + case as && bs: + VecMaxU64(at, bt) + case as && !bs: + MaxIterSVU64(at[0], bt, bit) + case !as && bs: + MaxIterVSU64(at, bt[0], ait) + default: + VecMaxIterU64(at, bt, ait, bit) + } + return + case Float32: + at 
:= a.Float32s() + bt := b.Float32s() + switch { + case as && bs: + VecMaxF32(at, bt) + case as && !bs: + MaxIterSVF32(at[0], bt, bit) + case !as && bs: + MaxIterVSF32(at, bt[0], ait) + default: + VecMaxIterF32(at, bt, ait, bit) + } + return + case Float64: + at := a.Float64s() + bt := b.Float64s() + switch { + case as && bs: + VecMaxF64(at, bt) + case as && !bs: + MaxIterSVF64(at[0], bt, bit) + case !as && bs: + MaxIterVSF64(at, bt[0], ait) + default: + VecMaxIterF64(at, bt, ait, bit) + } + return + case String: + at := a.Strings() + bt := b.Strings() + switch { + case as && bs: + VecMaxStr(at, bt) + case as && !bs: + MaxIterSVStr(at[0], bt, bit) + case !as && bs: + MaxIterVSStr(at, bt[0], ait) + default: + VecMaxIterStr(at, bt, ait, bit) + } + return + default: + return errors.Errorf("Unsupported type %v for Max", t) + } +} + +func (e E) MinBetweenIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { + as := isScalar(a, t) + bs := isScalar(b, t) + + switch t { + case Int: + at := a.Ints() + bt := b.Ints() + switch { + case as && bs: + VecMinI(at, bt) + case as && !bs: + MinIterSVI(at[0], bt, bit) + case !as && bs: + MinIterVSI(at, bt[0], ait) + default: + VecMinIterI(at, bt, ait, bit) + } + return + case Int8: + at := a.Int8s() + bt := b.Int8s() + switch { + case as && bs: + VecMinI8(at, bt) + case as && !bs: + MinIterSVI8(at[0], bt, bit) + case !as && bs: + MinIterVSI8(at, bt[0], ait) + default: + VecMinIterI8(at, bt, ait, bit) + } + return + case Int16: + at := a.Int16s() + bt := b.Int16s() + switch { + case as && bs: + VecMinI16(at, bt) + case as && !bs: + MinIterSVI16(at[0], bt, bit) + case !as && bs: + MinIterVSI16(at, bt[0], ait) + default: + VecMinIterI16(at, bt, ait, bit) + } + return + case Int32: + at := a.Int32s() + bt := b.Int32s() + switch { + case as && bs: + VecMinI32(at, bt) + case as && !bs: + MinIterSVI32(at[0], bt, bit) + case !as && bs: + MinIterVSI32(at, bt[0], ait) + default: + VecMinIterI32(at, bt, ait, bit) + } + return + case Int64: + at := a.Int64s() + bt := b.Int64s() + switch { + case as && bs: + VecMinI64(at, bt) + case as && !bs: + MinIterSVI64(at[0], bt, bit) + case !as && bs: + MinIterVSI64(at, bt[0], ait) + default: + VecMinIterI64(at, bt, ait, bit) + } + return + case Uint: + at := a.Uints() + bt := b.Uints() + switch { + case as && bs: + VecMinU(at, bt) + case as && !bs: + MinIterSVU(at[0], bt, bit) + case !as && bs: + MinIterVSU(at, bt[0], ait) + default: + VecMinIterU(at, bt, ait, bit) + } + return + case Uint8: + at := a.Uint8s() + bt := b.Uint8s() + switch { + case as && bs: + VecMinU8(at, bt) + case as && !bs: + MinIterSVU8(at[0], bt, bit) + case !as && bs: + MinIterVSU8(at, bt[0], ait) + default: + VecMinIterU8(at, bt, ait, bit) + } + return + case Uint16: + at := a.Uint16s() + bt := b.Uint16s() + switch { + case as && bs: + VecMinU16(at, bt) + case as && !bs: + MinIterSVU16(at[0], bt, bit) + case !as && bs: + MinIterVSU16(at, bt[0], ait) + default: + VecMinIterU16(at, bt, ait, bit) + } + return + case Uint32: + at := a.Uint32s() + bt := b.Uint32s() + switch { + case as && bs: + VecMinU32(at, bt) + case as && !bs: + MinIterSVU32(at[0], bt, bit) + case !as && bs: + MinIterVSU32(at, bt[0], ait) + default: + VecMinIterU32(at, bt, ait, bit) + } + return + case Uint64: + at := a.Uint64s() + bt := b.Uint64s() + switch { + case as && bs: + VecMinU64(at, bt) + case as && !bs: + MinIterSVU64(at[0], bt, bit) + case !as && bs: + MinIterVSU64(at, bt[0], ait) + default: + VecMinIterU64(at, bt, ait, bit) + } + 
return + case Float32: + at := a.Float32s() + bt := b.Float32s() + switch { + case as && bs: + VecMinF32(at, bt) + case as && !bs: + MinIterSVF32(at[0], bt, bit) + case !as && bs: + MinIterVSF32(at, bt[0], ait) + default: + VecMinIterF32(at, bt, ait, bit) + } + return + case Float64: + at := a.Float64s() + bt := b.Float64s() + switch { + case as && bs: + VecMinF64(at, bt) + case as && !bs: + MinIterSVF64(at[0], bt, bit) + case !as && bs: + MinIterVSF64(at, bt[0], ait) + default: + VecMinIterF64(at, bt, ait, bit) + } + return + case String: + at := a.Strings() + bt := b.Strings() + switch { + case as && bs: + VecMinStr(at, bt) + case as && !bs: + MinIterSVStr(at[0], bt, bit) + case !as && bs: + MinIterVSStr(at, bt[0], ait) + default: + VecMinIterStr(at, bt, ait, bit) + } + return + default: + return errors.Errorf("Unsupported type %v for Min", t) + } +} diff --git a/internal/execution/eng_reduce.go b/internal/execution/eng_reduce.go index 88c7ae5..bebe52f 100644 --- a/internal/execution/eng_reduce.go +++ b/internal/execution/eng_reduce.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -10,6 +8,8 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + func (e E) ReduceFirst(t reflect.Type, data *storage.Header, retVal *storage.Header, split int, size int, fn interface{}) (err error) { switch t { case Bool: diff --git a/internal/execution/eng_unary.go b/internal/execution/eng_unary.go index bd9bd81..4038190 100644 --- a/internal/execution/eng_unary.go +++ b/internal/execution/eng_unary.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -9,6 +7,8 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + func (e E) Neg(t reflect.Type, a *storage.Header) (err error) { switch t { case Int: diff --git a/internal/execution/generic_argmethods.go b/internal/execution/generic_argmethods.go index 3edb606..cdf4b7d 100644 --- a/internal/execution/generic_argmethods.go +++ b/internal/execution/generic_argmethods.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -8,6 +6,8 @@ import ( "github.com/chewxy/math32" ) +// Code generated by genlib2. DO NOT EDIT. + func ArgmaxI(a []int) int { var set bool var f int diff --git a/internal/execution/generic_arith_mixed.go b/internal/execution/generic_arith_mixed.go index 6e8aa72..94f5e8b 100644 --- a/internal/execution/generic_arith_mixed.go +++ b/internal/execution/generic_arith_mixed.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -9,6 +7,8 @@ import ( "github.com/chewxy/math32" ) +// Code generated by genlib2. DO NOT EDIT. + func AddSVI(a int, b []int) { for i := range b { b[i] = a + b[i] diff --git a/internal/execution/generic_arith_vv.go b/internal/execution/generic_arith_vv.go index 26f3772..a9f3a7a 100644 --- a/internal/execution/generic_arith_vv.go +++ b/internal/execution/generic_arith_vv.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -11,6 +9,8 @@ import ( "gorgonia.org/vecf64" ) +// Code generated by genlib2. DO NOT EDIT. 
+ func VecAddI(a []int, b []int) { a = a[:] b = b[:len(a)] @@ -4637,3 +4637,717 @@ func ModIterIncrF64(a []float64, b []float64, incr []float64, ait Iterator, bit } return } + +func AddRecvI(a []int, b []int, recv []int) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvI8(a []int8, b []int8, recv []int8) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvI16(a []int16, b []int16, recv []int16) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvI32(a []int32, b []int32, recv []int32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvI64(a []int64, b []int64, recv []int64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvU(a []uint, b []uint, recv []uint) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvU8(a []uint8, b []uint8, recv []uint8) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvU16(a []uint16, b []uint16, recv []uint16) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvU32(a []uint32, b []uint32, recv []uint32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvU64(a []uint64, b []uint64, recv []uint64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvF32(a []float32, b []float32, recv []float32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvF64(a []float64, b []float64, recv []float64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvC64(a []complex64, b []complex64, recv []complex64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvC128(a []complex128, b []complex128, recv []complex128) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvStr(a []string, b []string, recv []string) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func SubRecvI(a []int, b []int, recv []int) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func SubRecvI8(a []int8, b []int8, recv []int8) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func SubRecvI16(a []int16, b []int16, recv []int16) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func SubRecvI32(a []int32, b []int32, recv []int32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func SubRecvI64(a []int64, b []int64, recv []int64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func SubRecvU(a []uint, b []uint, recv []uint) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func SubRecvU8(a []uint8, b []uint8, recv []uint8) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func SubRecvU16(a 
[]uint16, b []uint16, recv []uint16) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func SubRecvU32(a []uint32, b []uint32, recv []uint32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func SubRecvU64(a []uint64, b []uint64, recv []uint64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func SubRecvF32(a []float32, b []float32, recv []float32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func SubRecvF64(a []float64, b []float64, recv []float64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func SubRecvC64(a []complex64, b []complex64, recv []complex64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func SubRecvC128(a []complex128, b []complex128, recv []complex128) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func MulRecvI(a []int, b []int, recv []int) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func MulRecvI8(a []int8, b []int8, recv []int8) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func MulRecvI16(a []int16, b []int16, recv []int16) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func MulRecvI32(a []int32, b []int32, recv []int32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func MulRecvI64(a []int64, b []int64, recv []int64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func MulRecvU(a []uint, b []uint, recv []uint) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func MulRecvU8(a []uint8, b []uint8, recv []uint8) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func MulRecvU16(a []uint16, b []uint16, recv []uint16) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func MulRecvU32(a []uint32, b []uint32, recv []uint32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func MulRecvU64(a []uint64, b []uint64, recv []uint64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func MulRecvF32(a []float32, b []float32, recv []float32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func MulRecvF64(a []float64, b []float64, recv []float64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func MulRecvC64(a []complex64, b []complex64, recv []complex64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func MulRecvC128(a []complex128, b []complex128, recv []complex128) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func DivRecvI(a []int, b []int, recv []int) (err error) { + a = a[:len(recv)] + b = b[:len(recv)] + var errs errorIndices + for i := range recv { + if b[i] == 0 { + errs = append(errs, i) + recv[i] = 0 + continue + } + recv[i] = a[i] / b[i] + } + if err != nil { + return + } + if len(errs) > 0 { + 
return errs + } + return nil +} + +func DivRecvI8(a []int8, b []int8, recv []int8) (err error) { + a = a[:len(recv)] + b = b[:len(recv)] + var errs errorIndices + for i := range recv { + if b[i] == 0 { + errs = append(errs, i) + recv[i] = 0 + continue + } + recv[i] = a[i] / b[i] + } + if err != nil { + return + } + if len(errs) > 0 { + return errs + } + return nil +} + +func DivRecvI16(a []int16, b []int16, recv []int16) (err error) { + a = a[:len(recv)] + b = b[:len(recv)] + var errs errorIndices + for i := range recv { + if b[i] == 0 { + errs = append(errs, i) + recv[i] = 0 + continue + } + recv[i] = a[i] / b[i] + } + if err != nil { + return + } + if len(errs) > 0 { + return errs + } + return nil +} + +func DivRecvI32(a []int32, b []int32, recv []int32) (err error) { + a = a[:len(recv)] + b = b[:len(recv)] + var errs errorIndices + for i := range recv { + if b[i] == 0 { + errs = append(errs, i) + recv[i] = 0 + continue + } + recv[i] = a[i] / b[i] + } + if err != nil { + return + } + if len(errs) > 0 { + return errs + } + return nil +} + +func DivRecvI64(a []int64, b []int64, recv []int64) (err error) { + a = a[:len(recv)] + b = b[:len(recv)] + var errs errorIndices + for i := range recv { + if b[i] == 0 { + errs = append(errs, i) + recv[i] = 0 + continue + } + recv[i] = a[i] / b[i] + } + if err != nil { + return + } + if len(errs) > 0 { + return errs + } + return nil +} + +func DivRecvU(a []uint, b []uint, recv []uint) (err error) { + a = a[:len(recv)] + b = b[:len(recv)] + var errs errorIndices + for i := range recv { + if b[i] == 0 { + errs = append(errs, i) + recv[i] = 0 + continue + } + recv[i] = a[i] / b[i] + } + if err != nil { + return + } + if len(errs) > 0 { + return errs + } + return nil +} + +func DivRecvU8(a []uint8, b []uint8, recv []uint8) (err error) { + a = a[:len(recv)] + b = b[:len(recv)] + var errs errorIndices + for i := range recv { + if b[i] == 0 { + errs = append(errs, i) + recv[i] = 0 + continue + } + recv[i] = a[i] / b[i] + } + if err != nil { + return + } + if len(errs) > 0 { + return errs + } + return nil +} + +func DivRecvU16(a []uint16, b []uint16, recv []uint16) (err error) { + a = a[:len(recv)] + b = b[:len(recv)] + var errs errorIndices + for i := range recv { + if b[i] == 0 { + errs = append(errs, i) + recv[i] = 0 + continue + } + recv[i] = a[i] / b[i] + } + if err != nil { + return + } + if len(errs) > 0 { + return errs + } + return nil +} + +func DivRecvU32(a []uint32, b []uint32, recv []uint32) (err error) { + a = a[:len(recv)] + b = b[:len(recv)] + var errs errorIndices + for i := range recv { + if b[i] == 0 { + errs = append(errs, i) + recv[i] = 0 + continue + } + recv[i] = a[i] / b[i] + } + if err != nil { + return + } + if len(errs) > 0 { + return errs + } + return nil +} + +func DivRecvU64(a []uint64, b []uint64, recv []uint64) (err error) { + a = a[:len(recv)] + b = b[:len(recv)] + var errs errorIndices + for i := range recv { + if b[i] == 0 { + errs = append(errs, i) + recv[i] = 0 + continue + } + recv[i] = a[i] / b[i] + } + if err != nil { + return + } + if len(errs) > 0 { + return errs + } + return nil +} + +func DivRecvF32(a []float32, b []float32, recv []float32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] / b[i] + } +} + +func DivRecvF64(a []float64, b []float64, recv []float64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] / b[i] + } +} + +func DivRecvC64(a []complex64, b []complex64, recv []complex64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range 
recv { + recv[i] = a[i] / b[i] + } +} + +func DivRecvC128(a []complex128, b []complex128, recv []complex128) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] / b[i] + } +} + +func PowRecvF32(a []float32, b []float32, recv []float32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = math32.Pow(a[i], b[i]) + } +} + +func PowRecvF64(a []float64, b []float64, recv []float64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = math.Pow(a[i], b[i]) + } +} + +func PowRecvC64(a []complex64, b []complex64, recv []complex64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = complex64(cmplx.Pow(complex128(a[i]), complex128(b[i]))) + } +} + +func PowRecvC128(a []complex128, b []complex128, recv []complex128) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = cmplx.Pow(a[i], b[i]) + } +} + +func ModRecvI(a []int, b []int, recv []int) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] % b[i] + } +} + +func ModRecvI8(a []int8, b []int8, recv []int8) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] % b[i] + } +} + +func ModRecvI16(a []int16, b []int16, recv []int16) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] % b[i] + } +} + +func ModRecvI32(a []int32, b []int32, recv []int32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] % b[i] + } +} + +func ModRecvI64(a []int64, b []int64, recv []int64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] % b[i] + } +} + +func ModRecvU(a []uint, b []uint, recv []uint) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] % b[i] + } +} + +func ModRecvU8(a []uint8, b []uint8, recv []uint8) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] % b[i] + } +} + +func ModRecvU16(a []uint16, b []uint16, recv []uint16) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] % b[i] + } +} + +func ModRecvU32(a []uint32, b []uint32, recv []uint32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] % b[i] + } +} + +func ModRecvU64(a []uint64, b []uint64, recv []uint64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] % b[i] + } +} + +func ModRecvF32(a []float32, b []float32, recv []float32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = math32.Mod(a[i], b[i]) + } +} + +func ModRecvF64(a []float64, b []float64, recv []float64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = math.Mod(a[i], b[i]) + } +} diff --git a/internal/execution/generic_cmp_mixed.go b/internal/execution/generic_cmp_mixed.go index b9a1154..1c53747 100644 --- a/internal/execution/generic_cmp_mixed.go +++ b/internal/execution/generic_cmp_mixed.go @@ -1,9 +1,9 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import "unsafe" +// Code generated by genlib2. DO NOT EDIT. + func GtSVI(a int, b []int, retVal []bool) { for i := range retVal { retVal[i] = a > b[i] diff --git a/internal/execution/generic_cmp_vv.go b/internal/execution/generic_cmp_vv.go index 7d528c4..a501f93 100644 --- a/internal/execution/generic_cmp_vv.go +++ b/internal/execution/generic_cmp_vv.go @@ -1,9 +1,9 @@ -// Code generated by genlib2. DO NOT EDIT. 
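
The new `*Recv` kernels above all share one shape: the destination slice `recv` fixes the length, `a` and `b` are re-sliced to `len(recv)` as a bounds-check-elimination hint, and the result of the element-wise op is written into `recv`. The integer `DivRecv*` variants additionally collect the positions of zero divisors instead of panicking, writing 0 at those positions and returning the collected indices as an error. A standalone sketch of that pattern (not part of this diff; the error type's message string here is illustrative only, while the index-collecting behaviour is taken from `DivRecvI` above):

```go
package main

import "fmt"

// errorIndices stands in for the package's error type that records the
// positions at which division by zero occurred.
type errorIndices []int

func (e errorIndices) Error() string {
	return fmt.Sprintf("division by zero at indices %v", []int(e))
}

// divRecv mirrors DivRecvI: a and b are re-sliced to len(recv) so the compiler
// can elide bounds checks, and zero divisors yield 0 in the output while their
// indices are collected rather than causing a runtime panic.
func divRecv(a, b, recv []int) error {
	a = a[:len(recv)]
	b = b[:len(recv)]
	var errs errorIndices
	for i := range recv {
		if b[i] == 0 {
			errs = append(errs, i)
			recv[i] = 0
			continue
		}
		recv[i] = a[i] / b[i]
	}
	if len(errs) > 0 {
		return errs
	}
	return nil
}

func main() {
	recv := make([]int, 4)
	err := divRecv([]int{6, 7, 8, 9}, []int{2, 0, 4, 3}, recv)
	fmt.Println(recv, err) // [3 0 2 3] division by zero at indices [1]
}
```

The float, complex, `Pow` and `Mod` variants follow the same destination-slice pattern but delegate to `/`, `math.Pow`/`math32.Pow`/`cmplx.Pow` and `%`/`math.Mod`/`math32.Mod` respectively, with no error accumulation.
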
- package execution import "unsafe" +// Code generated by genlib2. DO NOT EDIT. + func GtI(a []int, b []int, retVal []bool) { a = a[:] b = b[:len(a)] diff --git a/internal/execution/generic_map.go b/internal/execution/generic_map.go index 41c7de8..f054239 100644 --- a/internal/execution/generic_map.go +++ b/internal/execution/generic_map.go @@ -1,9 +1,9 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import "unsafe" +// Code generated by genlib2. DO NOT EDIT. + func MapB(fn func(bool) bool, a []bool) { for i := range a { a[i] = fn(a[i]) diff --git a/internal/execution/generic_minmax.go b/internal/execution/generic_minmax.go index 170f01b..011645b 100644 --- a/internal/execution/generic_minmax.go +++ b/internal/execution/generic_minmax.go @@ -1,7 +1,7 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution +// Code generated by genlib2. DO NOT EDIT. + func VecMinI(a, b []int) { a = a[:] b = b[:len(a)] @@ -12,6 +12,23 @@ func VecMinI(a, b []int) { } } } + +func MinSVI(a int, b []int) { + for i := range b { + if a < b[i] { + b[i] = a + } + } +} + +func MinVSI(a []int, b int) { + for i := range a { + if b < a[i] { + a[i] = b + } + } +} + func VecMaxI(a, b []int) { a = a[:] b = b[:len(a)] @@ -22,6 +39,22 @@ func VecMaxI(a, b []int) { } } } + +func MaxSVI(a int, b []int) { + for i := range b { + if a > b[i] { + b[i] = a + } + } +} + +func MaxVSI(a []int, b int) { + for i := range a { + if b > a[i] { + a[i] = b + } + } +} func VecMinI8(a, b []int8) { a = a[:] b = b[:len(a)] @@ -32,6 +65,23 @@ func VecMinI8(a, b []int8) { } } } + +func MinSVI8(a int8, b []int8) { + for i := range b { + if a < b[i] { + b[i] = a + } + } +} + +func MinVSI8(a []int8, b int8) { + for i := range a { + if b < a[i] { + a[i] = b + } + } +} + func VecMaxI8(a, b []int8) { a = a[:] b = b[:len(a)] @@ -42,6 +92,22 @@ func VecMaxI8(a, b []int8) { } } } + +func MaxSVI8(a int8, b []int8) { + for i := range b { + if a > b[i] { + b[i] = a + } + } +} + +func MaxVSI8(a []int8, b int8) { + for i := range a { + if b > a[i] { + a[i] = b + } + } +} func VecMinI16(a, b []int16) { a = a[:] b = b[:len(a)] @@ -52,6 +118,23 @@ func VecMinI16(a, b []int16) { } } } + +func MinSVI16(a int16, b []int16) { + for i := range b { + if a < b[i] { + b[i] = a + } + } +} + +func MinVSI16(a []int16, b int16) { + for i := range a { + if b < a[i] { + a[i] = b + } + } +} + func VecMaxI16(a, b []int16) { a = a[:] b = b[:len(a)] @@ -62,6 +145,22 @@ func VecMaxI16(a, b []int16) { } } } + +func MaxSVI16(a int16, b []int16) { + for i := range b { + if a > b[i] { + b[i] = a + } + } +} + +func MaxVSI16(a []int16, b int16) { + for i := range a { + if b > a[i] { + a[i] = b + } + } +} func VecMinI32(a, b []int32) { a = a[:] b = b[:len(a)] @@ -72,6 +171,23 @@ func VecMinI32(a, b []int32) { } } } + +func MinSVI32(a int32, b []int32) { + for i := range b { + if a < b[i] { + b[i] = a + } + } +} + +func MinVSI32(a []int32, b int32) { + for i := range a { + if b < a[i] { + a[i] = b + } + } +} + func VecMaxI32(a, b []int32) { a = a[:] b = b[:len(a)] @@ -82,6 +198,22 @@ func VecMaxI32(a, b []int32) { } } } + +func MaxSVI32(a int32, b []int32) { + for i := range b { + if a > b[i] { + b[i] = a + } + } +} + +func MaxVSI32(a []int32, b int32) { + for i := range a { + if b > a[i] { + a[i] = b + } + } +} func VecMinI64(a, b []int64) { a = a[:] b = b[:len(a)] @@ -92,6 +224,23 @@ func VecMinI64(a, b []int64) { } } } + +func MinSVI64(a int64, b []int64) { + for i := range b { + if a < b[i] { + b[i] = a + } + } +} + +func MinVSI64(a []int64, b int64) { + 
for i := range a { + if b < a[i] { + a[i] = b + } + } +} + func VecMaxI64(a, b []int64) { a = a[:] b = b[:len(a)] @@ -102,6 +251,22 @@ func VecMaxI64(a, b []int64) { } } } + +func MaxSVI64(a int64, b []int64) { + for i := range b { + if a > b[i] { + b[i] = a + } + } +} + +func MaxVSI64(a []int64, b int64) { + for i := range a { + if b > a[i] { + a[i] = b + } + } +} func VecMinU(a, b []uint) { a = a[:] b = b[:len(a)] @@ -112,6 +277,23 @@ func VecMinU(a, b []uint) { } } } + +func MinSVU(a uint, b []uint) { + for i := range b { + if a < b[i] { + b[i] = a + } + } +} + +func MinVSU(a []uint, b uint) { + for i := range a { + if b < a[i] { + a[i] = b + } + } +} + func VecMaxU(a, b []uint) { a = a[:] b = b[:len(a)] @@ -122,6 +304,22 @@ func VecMaxU(a, b []uint) { } } } + +func MaxSVU(a uint, b []uint) { + for i := range b { + if a > b[i] { + b[i] = a + } + } +} + +func MaxVSU(a []uint, b uint) { + for i := range a { + if b > a[i] { + a[i] = b + } + } +} func VecMinU8(a, b []uint8) { a = a[:] b = b[:len(a)] @@ -132,6 +330,23 @@ func VecMinU8(a, b []uint8) { } } } + +func MinSVU8(a uint8, b []uint8) { + for i := range b { + if a < b[i] { + b[i] = a + } + } +} + +func MinVSU8(a []uint8, b uint8) { + for i := range a { + if b < a[i] { + a[i] = b + } + } +} + func VecMaxU8(a, b []uint8) { a = a[:] b = b[:len(a)] @@ -142,6 +357,22 @@ func VecMaxU8(a, b []uint8) { } } } + +func MaxSVU8(a uint8, b []uint8) { + for i := range b { + if a > b[i] { + b[i] = a + } + } +} + +func MaxVSU8(a []uint8, b uint8) { + for i := range a { + if b > a[i] { + a[i] = b + } + } +} func VecMinU16(a, b []uint16) { a = a[:] b = b[:len(a)] @@ -152,6 +383,23 @@ func VecMinU16(a, b []uint16) { } } } + +func MinSVU16(a uint16, b []uint16) { + for i := range b { + if a < b[i] { + b[i] = a + } + } +} + +func MinVSU16(a []uint16, b uint16) { + for i := range a { + if b < a[i] { + a[i] = b + } + } +} + func VecMaxU16(a, b []uint16) { a = a[:] b = b[:len(a)] @@ -162,6 +410,22 @@ func VecMaxU16(a, b []uint16) { } } } + +func MaxSVU16(a uint16, b []uint16) { + for i := range b { + if a > b[i] { + b[i] = a + } + } +} + +func MaxVSU16(a []uint16, b uint16) { + for i := range a { + if b > a[i] { + a[i] = b + } + } +} func VecMinU32(a, b []uint32) { a = a[:] b = b[:len(a)] @@ -172,6 +436,23 @@ func VecMinU32(a, b []uint32) { } } } + +func MinSVU32(a uint32, b []uint32) { + for i := range b { + if a < b[i] { + b[i] = a + } + } +} + +func MinVSU32(a []uint32, b uint32) { + for i := range a { + if b < a[i] { + a[i] = b + } + } +} + func VecMaxU32(a, b []uint32) { a = a[:] b = b[:len(a)] @@ -182,6 +463,22 @@ func VecMaxU32(a, b []uint32) { } } } + +func MaxSVU32(a uint32, b []uint32) { + for i := range b { + if a > b[i] { + b[i] = a + } + } +} + +func MaxVSU32(a []uint32, b uint32) { + for i := range a { + if b > a[i] { + a[i] = b + } + } +} func VecMinU64(a, b []uint64) { a = a[:] b = b[:len(a)] @@ -192,6 +489,23 @@ func VecMinU64(a, b []uint64) { } } } + +func MinSVU64(a uint64, b []uint64) { + for i := range b { + if a < b[i] { + b[i] = a + } + } +} + +func MinVSU64(a []uint64, b uint64) { + for i := range a { + if b < a[i] { + a[i] = b + } + } +} + func VecMaxU64(a, b []uint64) { a = a[:] b = b[:len(a)] @@ -202,6 +516,22 @@ func VecMaxU64(a, b []uint64) { } } } + +func MaxSVU64(a uint64, b []uint64) { + for i := range b { + if a > b[i] { + b[i] = a + } + } +} + +func MaxVSU64(a []uint64, b uint64) { + for i := range a { + if b > a[i] { + a[i] = b + } + } +} func VecMinF32(a, b []float32) { a = a[:] b = b[:len(a)] @@ -212,6 +542,23 @@ func 
VecMinF32(a, b []float32) { } } } + +func MinSVF32(a float32, b []float32) { + for i := range b { + if a < b[i] { + b[i] = a + } + } +} + +func MinVSF32(a []float32, b float32) { + for i := range a { + if b < a[i] { + a[i] = b + } + } +} + func VecMaxF32(a, b []float32) { a = a[:] b = b[:len(a)] @@ -222,6 +569,22 @@ func VecMaxF32(a, b []float32) { } } } + +func MaxSVF32(a float32, b []float32) { + for i := range b { + if a > b[i] { + b[i] = a + } + } +} + +func MaxVSF32(a []float32, b float32) { + for i := range a { + if b > a[i] { + a[i] = b + } + } +} func VecMinF64(a, b []float64) { a = a[:] b = b[:len(a)] @@ -232,6 +595,23 @@ func VecMinF64(a, b []float64) { } } } + +func MinSVF64(a float64, b []float64) { + for i := range b { + if a < b[i] { + b[i] = a + } + } +} + +func MinVSF64(a []float64, b float64) { + for i := range a { + if b < a[i] { + a[i] = b + } + } +} + func VecMaxF64(a, b []float64) { a = a[:] b = b[:len(a)] @@ -242,6 +622,22 @@ func VecMaxF64(a, b []float64) { } } } + +func MaxSVF64(a float64, b []float64) { + for i := range b { + if a > b[i] { + b[i] = a + } + } +} + +func MaxVSF64(a []float64, b float64) { + for i := range a { + if b > a[i] { + a[i] = b + } + } +} func VecMinStr(a, b []string) { a = a[:] b = b[:len(a)] @@ -252,6 +648,23 @@ func VecMinStr(a, b []string) { } } } + +func MinSVStr(a string, b []string) { + for i := range b { + if a < b[i] { + b[i] = a + } + } +} + +func MinVSStr(a []string, b string) { + for i := range a { + if b < a[i] { + a[i] = b + } + } +} + func VecMaxStr(a, b []string) { a = a[:] b = b[:len(a)] @@ -262,6 +675,22 @@ func VecMaxStr(a, b []string) { } } } + +func MaxSVStr(a string, b []string) { + for i := range b { + if a > b[i] { + b[i] = a + } + } +} + +func MaxVSStr(a []string, b string) { + for i := range a { + if b > a[i] { + a[i] = b + } + } +} func MinI(a, b int) (c int) { if a < b { return a @@ -431,3 +860,1432 @@ func MaxStr(a, b string) (c string) { } return b } +func MinIterSVI(a int, b []int, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVSI(a []int, b int, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIterI(a, b []int, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + +func MaxIterSVI(a int, b []int, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVSI(a []int, b int, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func VecMaxIterI(a, b []int, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = 
handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} + +func MinIterSVI8(a int8, b []int8, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVSI8(a []int8, b int8, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIterI8(a, b []int8, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + +func MaxIterSVI8(a int8, b []int8, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVSI8(a []int8, b int8, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func VecMaxIterI8(a, b []int8, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} + +func MinIterSVI16(a int16, b []int16, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVSI16(a []int16, b int16, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIterI16(a, b []int16, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + +func MaxIterSVI16(a int16, b []int16, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVSI16(a []int16, b int16, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func VecMaxIterI16(a, b []int16, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if 
i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} + +func MinIterSVI32(a int32, b []int32, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVSI32(a []int32, b int32, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIterI32(a, b []int32, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + +func MaxIterSVI32(a int32, b []int32, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVSI32(a []int32, b int32, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func VecMaxIterI32(a, b []int32, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} + +func MinIterSVI64(a int64, b []int64, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVSI64(a []int64, b int64, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIterI64(a, b []int64, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + +func MaxIterSVI64(a int64, b []int64, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVSI64(a []int64, b int64, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func VecMaxIterI64(a, b []int64, ait, bit 
Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} + +func MinIterSVU(a uint, b []uint, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVSU(a []uint, b uint, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIterU(a, b []uint, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + +func MaxIterSVU(a uint, b []uint, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVSU(a []uint, b uint, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func VecMaxIterU(a, b []uint, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} + +func MinIterSVU8(a uint8, b []uint8, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVSU8(a []uint8, b uint8, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIterU8(a, b []uint8, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + +func MaxIterSVU8(a uint8, b []uint8, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVSU8(a []uint8, b uint8, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return 
+} + +func VecMaxIterU8(a, b []uint8, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} + +func MinIterSVU16(a uint16, b []uint16, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVSU16(a []uint16, b uint16, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIterU16(a, b []uint16, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + +func MaxIterSVU16(a uint16, b []uint16, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVSU16(a []uint16, b uint16, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func VecMaxIterU16(a, b []uint16, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} + +func MinIterSVU32(a uint32, b []uint32, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVSU32(a []uint32, b uint32, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIterU32(a, b []uint32, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + +func MaxIterSVU32(a uint32, b []uint32, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVSU32(a []uint32, b uint32, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err 
= handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func VecMaxIterU32(a, b []uint32, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} + +func MinIterSVU64(a uint64, b []uint64, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVSU64(a []uint64, b uint64, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIterU64(a, b []uint64, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + +func MaxIterSVU64(a uint64, b []uint64, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVSU64(a []uint64, b uint64, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func VecMaxIterU64(a, b []uint64, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} + +func MinIterSVF32(a float32, b []float32, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVSF32(a []float32, b float32, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIterF32(a, b []float32, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + +func MaxIterSVF32(a float32, b []float32, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVSF32(a []float32, b float32, ait Iterator) (err 
error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func VecMaxIterF32(a, b []float32, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} + +func MinIterSVF64(a float64, b []float64, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVSF64(a []float64, b float64, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIterF64(a, b []float64, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + +func MaxIterSVF64(a float64, b []float64, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVSF64(a []float64, b float64, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func VecMaxIterF64(a, b []float64, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} + +func MinIterSVStr(a string, b []string, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVSStr(a []string, b string, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIterStr(a, b []string, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + +func MaxIterSVStr(a string, b []string, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a > 
b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVSStr(a []string, b string, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func VecMaxIterStr(a, b []string, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} diff --git a/internal/execution/generic_reduce.go b/internal/execution/generic_reduce.go index a489f1c..ef94057 100644 --- a/internal/execution/generic_reduce.go +++ b/internal/execution/generic_reduce.go @@ -1,9 +1,9 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import "unsafe" +// Code generated by genlib2. DO NOT EDIT. + func ReduceB(f func(a, b bool) bool, def bool, l ...bool) (retVal bool) { retVal = def if len(l) == 0 { diff --git a/internal/execution/generic_unary.go b/internal/execution/generic_unary.go index cb3f87f..7c05acd 100644 --- a/internal/execution/generic_unary.go +++ b/internal/execution/generic_unary.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -9,6 +7,8 @@ import ( "github.com/chewxy/math32" ) +// Code generated by genlib2. DO NOT EDIT. + func NegI(a []int) { for i := range a { a[i] = -a[i] diff --git a/internal/execution/keepsync.go b/internal/execution/keepsync.go index 5b49f7d..8921d1c 100644 --- a/internal/execution/keepsync.go +++ b/internal/execution/keepsync.go @@ -19,18 +19,12 @@ type NoOpError interface { NoOp() bool } -type noopError struct{} - -func (e noopError) NoOp() bool { return true } -func (e noopError) Error() string { return "NoOp" } - func handleNoOp(err error) error { if err == nil { return nil } - - if _, ok := err.(NoOpError); !ok { - return err + if _, ok := err.(NoOpError); ok { + return nil } - return nil + return err } diff --git a/internal/execution/keepsync_test.go b/internal/execution/keepsync_test.go new file mode 100644 index 0000000..2e2c693 --- /dev/null +++ b/internal/execution/keepsync_test.go @@ -0,0 +1,20 @@ +package execution + +import ( + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "testing" +) + +type noopError struct{} + +func (e noopError) NoOp() bool { return true } +func (e noopError) Error() string { return "NoOp" } + +func TestHandleNoOp(t *testing.T) { + otherErr := errors.New("other error") + + assert.Equal(t, nil, handleNoOp(noopError{})) + assert.Equal(t, nil, handleNoOp(nil)) + assert.Equal(t, otherErr, handleNoOp(otherErr)) +} diff --git a/internal/execution/reduction_specialization.go b/internal/execution/reduction_specialization.go index e83e67e..90cfe69 100644 --- a/internal/execution/reduction_specialization.go +++ b/internal/execution/reduction_specialization.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -9,6 +7,8 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. 
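
Two related pieces land together here: scalar/vector `MinSV*`/`MinVS*`/`MaxSV*`/`MaxVS*` clamps plus their iterator-driven counterparts (`MinIterSV*`, `VecMinIter*`, …), and the simplified `handleNoOp` in keepsync.go, whose `noopError` helper now lives only in the new test file. Every iterator loop calls `Iterator.NextValidity()` until it errors and filters that error through `handleNoOp`, so ordinary end-of-iteration (a `NoOpError`) terminates the loop cleanly while real errors propagate. A minimal, self-contained sketch of that contract (not part of this diff; `sliceIter` is a hypothetical stand-in for the package's real iterators):

```go
package main

import "fmt"

type noopError struct{}

func (noopError) NoOp() bool    { return true }
func (noopError) Error() string { return "NoOp" }

// handleNoOp mirrors the keepsync.go change: nil and NoOp errors are swallowed,
// anything else is passed through unchanged.
func handleNoOp(err error) error {
	if err == nil {
		return nil
	}
	if _, ok := err.(interface{ NoOp() bool }); ok {
		return nil
	}
	return err
}

// sliceIter is a toy iterator over indices 0..n-1 that reports every index as
// valid and signals exhaustion with a NoOp error, mimicking NextValidity.
type sliceIter struct{ i, n int }

func (it *sliceIter) NextValidity() (int, bool, error) {
	if it.i >= it.n {
		return 0, false, noopError{}
	}
	i := it.i
	it.i++
	return i, true, nil
}

// minIterVS mirrors MinIterVSF64: clamp every valid element of a to at most b.
func minIterVS(a []float64, b float64, ait *sliceIter) (err error) {
	var i int
	var valid bool
	for {
		if i, valid, err = ait.NextValidity(); err != nil {
			return handleNoOp(err)
		}
		if valid && b < a[i] {
			a[i] = b
		}
	}
}

func main() {
	a := []float64{3, 1, 4, 1, 5}
	fmt.Println(minIterVS(a, 2, &sliceIter{n: len(a)}), a) // <nil> [2 1 2 1 2]
}
```
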
+ func MonotonicSum(t reflect.Type, a *storage.Header) (retVal interface{}, err error) { switch t { case Int: diff --git a/internal/serialization/README.md b/internal/serialization/README.md new file mode 100644 index 0000000..d3d8149 --- /dev/null +++ b/internal/serialization/README.md @@ -0,0 +1,33 @@ +# Serialization # + +This pseudopackage of sorts handles serialization. The "Canonical" serialized data structure is found in the `pb` subdirectory. + +# Protobuf generation + +Proteus needs to be installed, as does its dependencies. + + +1. `cd pb` +2. `rm generated*` +3. `proteus -f ../../IDLs -p gorgonia.org/tensor/internal/serialization/pb` +4. `cd ../../IDLs` +5. `find gorgonia.org/ -mindepth 2 -type f -exec mv -i '{}' . ';'` +6. `rm -rf gorgonia.org` + + +# FlatBuffer generation +1. generate protobuf first +2. delete the `import "github.com/gogo/protobuf/gogoproto/gogo.proto";` line from the generated protobuf file +3. `flatc --proto PATH/TO/generated.proto` +4. place the `generated.fbs` file in the IDLs directory +4. restore the import line in the `generated.proto` file +5. From this directory: `flatc --go-namespace fb -g PATH/TO/generated.fbs` + + +# Notes # + +`find gorgonia.org/ -mindepth 2 -type f -exec mv -i '{}' . ';'` is used to flatten and put all the stuff in the root IDLs directory. + +# The Serialization Story # + +To serialize, we copy/convert/coerce the data to the internal/serialization data structures, then call the `Marshall` methods from there \ No newline at end of file diff --git a/internal/serialization/doc.go b/internal/serialization/doc.go new file mode 100644 index 0000000..c4cb59b --- /dev/null +++ b/internal/serialization/doc.go @@ -0,0 +1,2 @@ +// package serialization provides the data structures for serialization +package serialization diff --git a/internal/serialization/fb/AP.go b/internal/serialization/fb/AP.go new file mode 100644 index 0000000..b3ca806 --- /dev/null +++ b/internal/serialization/fb/AP.go @@ -0,0 +1,110 @@ +// automatically generated by the FlatBuffers compiler, do not modify + +package fb + +import ( + flatbuffers "github.com/google/flatbuffers/go" +) + +type AP struct { + _tab flatbuffers.Table +} + +func GetRootAsAP(buf []byte, offset flatbuffers.UOffsetT) *AP { + n := flatbuffers.GetUOffsetT(buf[offset:]) + x := &AP{} + x.Init(buf, n+offset) + return x +} + +func (rcv *AP) Init(buf []byte, i flatbuffers.UOffsetT) { + rcv._tab.Bytes = buf + rcv._tab.Pos = i +} + +func (rcv *AP) Table() flatbuffers.Table { + return rcv._tab +} + +func (rcv *AP) Shape(j int) int32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetInt32(a + flatbuffers.UOffsetT(j*4)) + } + return 0 +} + +func (rcv *AP) ShapeLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *AP) Strides(j int) int32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetInt32(a + flatbuffers.UOffsetT(j*4)) + } + return 0 +} + +func (rcv *AP) StridesLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *AP) O() uint32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + return rcv._tab.GetUint32(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *AP) MutateO(n uint32) bool { + return rcv._tab.MutateUint32Slot(8, n) +} + +func (rcv *AP) T() int32 { + o := 
flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + return rcv._tab.GetInt32(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *AP) MutateT(n int32) bool { + return rcv._tab.MutateInt32Slot(10, n) +} + +func APStart(builder *flatbuffers.Builder) { + builder.StartObject(4) +} +func APAddShape(builder *flatbuffers.Builder, shape flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(shape), 0) +} +func APStartShapeVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(4, numElems, 4) +} +func APAddStrides(builder *flatbuffers.Builder, strides flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(strides), 0) +} +func APStartStridesVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(4, numElems, 4) +} +func APAddO(builder *flatbuffers.Builder, o uint32) { + builder.PrependUint32Slot(2, o, 0) +} +func APAddT(builder *flatbuffers.Builder, t int32) { + builder.PrependInt32Slot(3, t, 0) +} +func APEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + return builder.EndObject() +} diff --git a/internal/serialization/fb/Dense.go b/internal/serialization/fb/Dense.go new file mode 100644 index 0000000..2a961ee --- /dev/null +++ b/internal/serialization/fb/Dense.go @@ -0,0 +1,152 @@ +// automatically generated by the FlatBuffers compiler, do not modify + +package fb + +import ( + flatbuffers "github.com/google/flatbuffers/go" +) + +type Dense struct { + _tab flatbuffers.Table +} + +func GetRootAsDense(buf []byte, offset flatbuffers.UOffsetT) *Dense { + n := flatbuffers.GetUOffsetT(buf[offset:]) + x := &Dense{} + x.Init(buf, n+offset) + return x +} + +func (rcv *Dense) Init(buf []byte, i flatbuffers.UOffsetT) { + rcv._tab.Bytes = buf + rcv._tab.Pos = i +} + +func (rcv *Dense) Table() flatbuffers.Table { + return rcv._tab +} + +func (rcv *Dense) Shape(j int) int32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetInt32(a + flatbuffers.UOffsetT(j*4)) + } + return 0 +} + +func (rcv *Dense) ShapeLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *Dense) Strides(j int) int32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetInt32(a + flatbuffers.UOffsetT(j*4)) + } + return 0 +} + +func (rcv *Dense) StridesLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *Dense) O() uint32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + return rcv._tab.GetUint32(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *Dense) MutateO(n uint32) bool { + return rcv._tab.MutateUint32Slot(8, n) +} + +func (rcv *Dense) T() int32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + return rcv._tab.GetInt32(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *Dense) MutateT(n int32) bool { + return rcv._tab.MutateInt32Slot(10, n) +} + +func (rcv *Dense) Type() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(12)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *Dense) Data(j int) byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(14)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetByte(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *Dense) DataLength() int { + o := 
flatbuffers.UOffsetT(rcv._tab.Offset(14)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *Dense) DataBytes() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(14)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func DenseStart(builder *flatbuffers.Builder) { + builder.StartObject(6) +} +func DenseAddShape(builder *flatbuffers.Builder, shape flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(shape), 0) +} +func DenseStartShapeVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(4, numElems, 4) +} +func DenseAddStrides(builder *flatbuffers.Builder, strides flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(strides), 0) +} +func DenseStartStridesVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(4, numElems, 4) +} +func DenseAddO(builder *flatbuffers.Builder, o uint32) { + builder.PrependUint32Slot(2, o, 0) +} +func DenseAddT(builder *flatbuffers.Builder, t int32) { + builder.PrependInt32Slot(3, t, 0) +} +func DenseAddType(builder *flatbuffers.Builder, type_ flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(4, flatbuffers.UOffsetT(type_), 0) +} +func DenseAddData(builder *flatbuffers.Builder, data flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(5, flatbuffers.UOffsetT(data), 0) +} +func DenseStartDataVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} +func DenseEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + return builder.EndObject() +} diff --git a/internal/serialization/fb/MaskedDense.go b/internal/serialization/fb/MaskedDense.go new file mode 100644 index 0000000..271e77e --- /dev/null +++ b/internal/serialization/fb/MaskedDense.go @@ -0,0 +1,198 @@ +// automatically generated by the FlatBuffers compiler, do not modify + +package fb + +import ( + flatbuffers "github.com/google/flatbuffers/go" +) + +type MaskedDense struct { + _tab flatbuffers.Table +} + +func GetRootAsMaskedDense(buf []byte, offset flatbuffers.UOffsetT) *MaskedDense { + n := flatbuffers.GetUOffsetT(buf[offset:]) + x := &MaskedDense{} + x.Init(buf, n+offset) + return x +} + +func (rcv *MaskedDense) Init(buf []byte, i flatbuffers.UOffsetT) { + rcv._tab.Bytes = buf + rcv._tab.Pos = i +} + +func (rcv *MaskedDense) Table() flatbuffers.Table { + return rcv._tab +} + +func (rcv *MaskedDense) Shape(j int) int32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetInt32(a + flatbuffers.UOffsetT(j*4)) + } + return 0 +} + +func (rcv *MaskedDense) ShapeLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *MaskedDense) Strides(j int) int32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetInt32(a + flatbuffers.UOffsetT(j*4)) + } + return 0 +} + +func (rcv *MaskedDense) StridesLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *MaskedDense) O() uint32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + return rcv._tab.GetUint32(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *MaskedDense) MutateO(n uint32) bool { + return rcv._tab.MutateUint32Slot(8, n) +} + +func (rcv *MaskedDense) T() int32 { + o := 
flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + return rcv._tab.GetInt32(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *MaskedDense) MutateT(n int32) bool { + return rcv._tab.MutateInt32Slot(10, n) +} + +func (rcv *MaskedDense) Type() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(12)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *MaskedDense) Data(j int) byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(14)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetByte(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *MaskedDense) DataLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(14)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *MaskedDense) DataBytes() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(14)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *MaskedDense) Mask(j int) byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(16)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetByte(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *MaskedDense) MaskLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(16)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *MaskedDense) MaskIsSoft(j int) byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(18)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetByte(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *MaskedDense) MaskIsSoftLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(18)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func MaskedDenseStart(builder *flatbuffers.Builder) { + builder.StartObject(8) +} +func MaskedDenseAddShape(builder *flatbuffers.Builder, shape flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(shape), 0) +} +func MaskedDenseStartShapeVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(4, numElems, 4) +} +func MaskedDenseAddStrides(builder *flatbuffers.Builder, strides flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(strides), 0) +} +func MaskedDenseStartStridesVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(4, numElems, 4) +} +func MaskedDenseAddO(builder *flatbuffers.Builder, o uint32) { + builder.PrependUint32Slot(2, o, 0) +} +func MaskedDenseAddT(builder *flatbuffers.Builder, t int32) { + builder.PrependInt32Slot(3, t, 0) +} +func MaskedDenseAddType(builder *flatbuffers.Builder, type_ flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(4, flatbuffers.UOffsetT(type_), 0) +} +func MaskedDenseAddData(builder *flatbuffers.Builder, data flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(5, flatbuffers.UOffsetT(data), 0) +} +func MaskedDenseStartDataVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} +func MaskedDenseAddMask(builder *flatbuffers.Builder, mask flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(6, flatbuffers.UOffsetT(mask), 0) +} +func MaskedDenseStartMaskVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} +func MaskedDenseAddMaskIsSoft(builder *flatbuffers.Builder, maskIsSoft flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(7, flatbuffers.UOffsetT(maskIsSoft), 0) +} +func MaskedDenseStartMaskIsSoftVector(builder 
*flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} +func MaskedDenseEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + return builder.EndObject() +} diff --git a/internal/serialization/fb/Triangle.go b/internal/serialization/fb/Triangle.go new file mode 100644 index 0000000..599a06b --- /dev/null +++ b/internal/serialization/fb/Triangle.go @@ -0,0 +1,18 @@ +// automatically generated by the FlatBuffers compiler, do not modify + +package fb + +const ( + TriangleNOT_TRIANGLE = 0 + TriangleUPPER = 1 + TriangleLOWER = 2 + TriangleSYMMETRIC = 3 +) + +var EnumNamesTriangle = map[int]string{ + TriangleNOT_TRIANGLE:"NOT_TRIANGLE", + TriangleUPPER:"UPPER", + TriangleLOWER:"LOWER", + TriangleSYMMETRIC:"SYMMETRIC", +} + diff --git a/internal/serialization/pb/dense.go b/internal/serialization/pb/dense.go new file mode 100644 index 0000000..950c3ff --- /dev/null +++ b/internal/serialization/pb/dense.go @@ -0,0 +1,45 @@ +package pb + +//proteus:generate +type DataOrder byte + +// the reason for spreading the states out is because proteaus cannot handle non-iota tates +const ( + RowMajorContiguous = iota + RowMajorNonContiguous + ColMajorContiguous + ColMajorNonContiguous +) + +//proteus:generate +type Triangle byte + +const ( + NotTriangle Triangle = iota + Upper + Lower + Symmetric +) + +//proteus:generate +type AP struct { + Shape []int32 + Strides []int32 + + O DataOrder + T Triangle +} + +//proteus:generate +type Dense struct { + AP + Type string // type name + Data []byte +} + +//proteus:generate +type MaskedDense struct { + Dense + Mask []bool + MaskIsSoft []bool +} diff --git a/internal/serialization/pb/generated.pb.go b/internal/serialization/pb/generated.pb.go new file mode 100644 index 0000000..831ce90 --- /dev/null +++ b/internal/serialization/pb/generated.pb.go @@ -0,0 +1,1457 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: gorgonia.org/tensor/internal/serialization/pb/generated.proto + +/* + Package pb is a generated protocol buffer package. + + It is generated from these files: + gorgonia.org/tensor/internal/serialization/pb/generated.proto + + It has these top-level messages: + AP + Dense + MaskedDense +*/ +package pb + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" +import _ "github.com/gogo/protobuf/gogoproto" + +import io "io" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. 
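// The hand-rolled Marshal/Unmarshal methods in this file rely on protocol
// buffer varint encoding: each integer is written 7 bits at a time, least
// significant group first, with the high bit set on every byte except the
// last. A minimal sketch of the writer side (the name putUvarint is
// illustrative only; the generated equivalent is encodeVarintGenerated below):
//
//	func putUvarint(dst []byte, v uint64) int {
//		i := 0
//		for v >= 1<<7 {
//			dst[i] = byte(v&0x7f | 0x80) // low 7 bits, continuation bit set
//			v >>= 7
//			i++
//		}
//		dst[i] = byte(v) // final byte, continuation bit clear
//		return i + 1
//	}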
+const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +var Triangle_name = map[int32]string{ + 0: "NOT_TRIANGLE", + 1: "UPPER", + 2: "LOWER", + 3: "SYMMETRIC", +} +var Triangle_value = map[string]int32{ + "NOT_TRIANGLE": 0, + "UPPER": 1, + "LOWER": 2, + "SYMMETRIC": 3, +} + +func (Triangle) EnumDescriptor() ([]byte, []int) { return fileDescriptorGenerated, []int{0} } + +func (m *AP) Reset() { *m = AP{} } +func (m *AP) String() string { return proto.CompactTextString(m) } +func (*AP) ProtoMessage() {} +func (*AP) Descriptor() ([]byte, []int) { return fileDescriptorGenerated, []int{0} } + +func (m *Dense) Reset() { *m = Dense{} } +func (m *Dense) String() string { return proto.CompactTextString(m) } +func (*Dense) ProtoMessage() {} +func (*Dense) Descriptor() ([]byte, []int) { return fileDescriptorGenerated, []int{1} } + +func (m *MaskedDense) Reset() { *m = MaskedDense{} } +func (m *MaskedDense) String() string { return proto.CompactTextString(m) } +func (*MaskedDense) ProtoMessage() {} +func (*MaskedDense) Descriptor() ([]byte, []int) { return fileDescriptorGenerated, []int{2} } + +func init() { + proto.RegisterType((*AP)(nil), "gorgonia.org.tensor.internal.serialization.pb.AP") + proto.RegisterType((*Dense)(nil), "gorgonia.org.tensor.internal.serialization.pb.Dense") + proto.RegisterType((*MaskedDense)(nil), "gorgonia.org.tensor.internal.serialization.pb.MaskedDense") + proto.RegisterEnum("gorgonia.org.tensor.internal.serialization.pb.Triangle", Triangle_name, Triangle_value) +} +func (m *AP) Marshal() (dAtA []byte, err error) { + size := m.ProtoSize() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *AP) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.Shape) > 0 { + dAtA2 := make([]byte, len(m.Shape)*10) + var j1 int + for _, num1 := range m.Shape { + num := uint64(num1) + for num >= 1<<7 { + dAtA2[j1] = uint8(uint64(num)&0x7f | 0x80) + num >>= 7 + j1++ + } + dAtA2[j1] = uint8(num) + j1++ + } + dAtA[i] = 0xa + i++ + i = encodeVarintGenerated(dAtA, i, uint64(j1)) + i += copy(dAtA[i:], dAtA2[:j1]) + } + if len(m.Strides) > 0 { + dAtA4 := make([]byte, len(m.Strides)*10) + var j3 int + for _, num1 := range m.Strides { + num := uint64(num1) + for num >= 1<<7 { + dAtA4[j3] = uint8(uint64(num)&0x7f | 0x80) + num >>= 7 + j3++ + } + dAtA4[j3] = uint8(num) + j3++ + } + dAtA[i] = 0x12 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(j3)) + i += copy(dAtA[i:], dAtA4[:j3]) + } + if m.O != 0 { + dAtA[i] = 0x18 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(m.O)) + } + if m.T != 0 { + dAtA[i] = 0x20 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(m.T)) + } + return i, nil +} + +func (m *Dense) Marshal() (dAtA []byte, err error) { + size := m.ProtoSize() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Dense) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.Shape) > 0 { + dAtA6 := make([]byte, len(m.Shape)*10) + var j5 int + for _, num1 := range m.Shape { + num := uint64(num1) + for num >= 1<<7 { + dAtA6[j5] = uint8(uint64(num)&0x7f | 0x80) + num >>= 7 + j5++ + } + dAtA6[j5] = uint8(num) + j5++ + } + dAtA[i] = 0xa + i++ + i = encodeVarintGenerated(dAtA, i, uint64(j5)) + i += copy(dAtA[i:], dAtA6[:j5]) + } + if len(m.Strides) > 0 { + dAtA8 := make([]byte, len(m.Strides)*10) + var j7 int + for _, num1 := range m.Strides { + 
num := uint64(num1) + for num >= 1<<7 { + dAtA8[j7] = uint8(uint64(num)&0x7f | 0x80) + num >>= 7 + j7++ + } + dAtA8[j7] = uint8(num) + j7++ + } + dAtA[i] = 0x12 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(j7)) + i += copy(dAtA[i:], dAtA8[:j7]) + } + if m.O != 0 { + dAtA[i] = 0x18 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(m.O)) + } + if m.T != 0 { + dAtA[i] = 0x20 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(m.T)) + } + if len(m.Type) > 0 { + dAtA[i] = 0x2a + i++ + i = encodeVarintGenerated(dAtA, i, uint64(len(m.Type))) + i += copy(dAtA[i:], m.Type) + } + if len(m.Data) > 0 { + dAtA[i] = 0x32 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(len(m.Data))) + i += copy(dAtA[i:], m.Data) + } + return i, nil +} + +func (m *MaskedDense) Marshal() (dAtA []byte, err error) { + size := m.ProtoSize() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *MaskedDense) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.Shape) > 0 { + dAtA10 := make([]byte, len(m.Shape)*10) + var j9 int + for _, num1 := range m.Shape { + num := uint64(num1) + for num >= 1<<7 { + dAtA10[j9] = uint8(uint64(num)&0x7f | 0x80) + num >>= 7 + j9++ + } + dAtA10[j9] = uint8(num) + j9++ + } + dAtA[i] = 0xa + i++ + i = encodeVarintGenerated(dAtA, i, uint64(j9)) + i += copy(dAtA[i:], dAtA10[:j9]) + } + if len(m.Strides) > 0 { + dAtA12 := make([]byte, len(m.Strides)*10) + var j11 int + for _, num1 := range m.Strides { + num := uint64(num1) + for num >= 1<<7 { + dAtA12[j11] = uint8(uint64(num)&0x7f | 0x80) + num >>= 7 + j11++ + } + dAtA12[j11] = uint8(num) + j11++ + } + dAtA[i] = 0x12 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(j11)) + i += copy(dAtA[i:], dAtA12[:j11]) + } + if m.O != 0 { + dAtA[i] = 0x18 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(m.O)) + } + if m.T != 0 { + dAtA[i] = 0x20 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(m.T)) + } + if len(m.Type) > 0 { + dAtA[i] = 0x2a + i++ + i = encodeVarintGenerated(dAtA, i, uint64(len(m.Type))) + i += copy(dAtA[i:], m.Type) + } + if len(m.Data) > 0 { + dAtA[i] = 0x32 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(len(m.Data))) + i += copy(dAtA[i:], m.Data) + } + if len(m.Mask) > 0 { + dAtA[i] = 0x3a + i++ + i = encodeVarintGenerated(dAtA, i, uint64(len(m.Mask))) + for _, b := range m.Mask { + if b { + dAtA[i] = 1 + } else { + dAtA[i] = 0 + } + i++ + } + } + if len(m.MaskIsSoft) > 0 { + dAtA[i] = 0x42 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(len(m.MaskIsSoft))) + for _, b := range m.MaskIsSoft { + if b { + dAtA[i] = 1 + } else { + dAtA[i] = 0 + } + i++ + } + } + return i, nil +} + +func encodeVarintGenerated(dAtA []byte, offset int, v uint64) int { + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return offset + 1 +} +func (m *AP) ProtoSize() (n int) { + var l int + _ = l + if len(m.Shape) > 0 { + l = 0 + for _, e := range m.Shape { + l += sovGenerated(uint64(e)) + } + n += 1 + sovGenerated(uint64(l)) + l + } + if len(m.Strides) > 0 { + l = 0 + for _, e := range m.Strides { + l += sovGenerated(uint64(e)) + } + n += 1 + sovGenerated(uint64(l)) + l + } + if m.O != 0 { + n += 1 + sovGenerated(uint64(m.O)) + } + if m.T != 0 { + n += 1 + sovGenerated(uint64(m.T)) + } + return n +} + +func (m *Dense) ProtoSize() (n int) { + var l int + _ = l + if len(m.Shape) > 0 { + l = 0 + for _, e := range m.Shape { + l += sovGenerated(uint64(e)) + } + n += 1 + 
sovGenerated(uint64(l)) + l + } + if len(m.Strides) > 0 { + l = 0 + for _, e := range m.Strides { + l += sovGenerated(uint64(e)) + } + n += 1 + sovGenerated(uint64(l)) + l + } + if m.O != 0 { + n += 1 + sovGenerated(uint64(m.O)) + } + if m.T != 0 { + n += 1 + sovGenerated(uint64(m.T)) + } + l = len(m.Type) + if l > 0 { + n += 1 + l + sovGenerated(uint64(l)) + } + l = len(m.Data) + if l > 0 { + n += 1 + l + sovGenerated(uint64(l)) + } + return n +} + +func (m *MaskedDense) ProtoSize() (n int) { + var l int + _ = l + if len(m.Shape) > 0 { + l = 0 + for _, e := range m.Shape { + l += sovGenerated(uint64(e)) + } + n += 1 + sovGenerated(uint64(l)) + l + } + if len(m.Strides) > 0 { + l = 0 + for _, e := range m.Strides { + l += sovGenerated(uint64(e)) + } + n += 1 + sovGenerated(uint64(l)) + l + } + if m.O != 0 { + n += 1 + sovGenerated(uint64(m.O)) + } + if m.T != 0 { + n += 1 + sovGenerated(uint64(m.T)) + } + l = len(m.Type) + if l > 0 { + n += 1 + l + sovGenerated(uint64(l)) + } + l = len(m.Data) + if l > 0 { + n += 1 + l + sovGenerated(uint64(l)) + } + if len(m.Mask) > 0 { + n += 1 + sovGenerated(uint64(len(m.Mask))) + len(m.Mask)*1 + } + if len(m.MaskIsSoft) > 0 { + n += 1 + sovGenerated(uint64(len(m.MaskIsSoft))) + len(m.MaskIsSoft)*1 + } + return n +} + +func sovGenerated(x uint64) (n int) { + for { + n++ + x >>= 7 + if x == 0 { + break + } + } + return n +} +func sozGenerated(x uint64) (n int) { + return sovGenerated(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (m *AP) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: AP: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: AP: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType == 0 { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Shape = append(m.Shape, v) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + packedLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Shape = append(m.Shape, v) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field Shape", wireType) + } + case 2: + if wireType == 0 { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 
0x7F) << shift + if b < 0x80 { + break + } + } + m.Strides = append(m.Strides, v) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + packedLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Strides = append(m.Strides, v) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field Strides", wireType) + } + case 3: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field O", wireType) + } + m.O = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.O |= (DataOrder(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 4: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field T", wireType) + } + m.T = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.T |= (Triangle(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipGenerated(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthGenerated + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *Dense) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Dense: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Dense: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType == 0 { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Shape = append(m.Shape, v) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + packedLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx 
>= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Shape = append(m.Shape, v) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field Shape", wireType) + } + case 2: + if wireType == 0 { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Strides = append(m.Strides, v) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + packedLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Strides = append(m.Strides, v) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field Strides", wireType) + } + case 3: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field O", wireType) + } + m.O = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.O |= (DataOrder(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 4: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field T", wireType) + } + m.T = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.T |= (Triangle(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 5: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Type", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Type = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 6: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Data", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + byteLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Data = append(m.Data[:0], dAtA[iNdEx:postIndex]...) 
+ if m.Data == nil { + m.Data = []byte{} + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipGenerated(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthGenerated + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *MaskedDense) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: MaskedDense: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: MaskedDense: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType == 0 { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Shape = append(m.Shape, v) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + packedLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Shape = append(m.Shape, v) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field Shape", wireType) + } + case 2: + if wireType == 0 { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Strides = append(m.Strides, v) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + packedLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Strides = append(m.Strides, v) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field Strides", wireType) + } + case 3: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field O", wireType) + } + m.O = 0 + for shift := uint(0); ; 
shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.O |= (DataOrder(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 4: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field T", wireType) + } + m.T = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.T |= (Triangle(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 5: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Type", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Type = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 6: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Data", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + byteLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Data = append(m.Data[:0], dAtA[iNdEx:postIndex]...) 
+ if m.Data == nil { + m.Data = []byte{} + } + iNdEx = postIndex + case 7: + if wireType == 0 { + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Mask = append(m.Mask, bool(v != 0)) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + packedLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Mask = append(m.Mask, bool(v != 0)) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field Mask", wireType) + } + case 8: + if wireType == 0 { + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.MaskIsSoft = append(m.MaskIsSoft, bool(v != 0)) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + packedLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.MaskIsSoft = append(m.MaskIsSoft, bool(v != 0)) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field MaskIsSoft", wireType) + } + default: + iNdEx = preIndex + skippy, err := skipGenerated(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthGenerated + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipGenerated(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowGenerated + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowGenerated + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + return iNdEx, nil + case 1: + iNdEx += 8 + return iNdEx, nil + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + 
return 0, ErrIntOverflowGenerated + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + iNdEx += length + if length < 0 { + return 0, ErrInvalidLengthGenerated + } + return iNdEx, nil + case 3: + for { + var innerWire uint64 + var start int = iNdEx + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowGenerated + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + innerWire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + innerWireType := int(innerWire & 0x7) + if innerWireType == 4 { + break + } + next, err := skipGenerated(dAtA[start:]) + if err != nil { + return 0, err + } + iNdEx = start + next + } + return iNdEx, nil + case 4: + return iNdEx, nil + case 5: + iNdEx += 4 + return iNdEx, nil + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + } + panic("unreachable") +} + +var ( + ErrInvalidLengthGenerated = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowGenerated = fmt.Errorf("proto: integer overflow") +) + +func init() { + proto.RegisterFile("gorgonia.org/tensor/internal/serialization/pb/generated.proto", fileDescriptorGenerated) +} + +var fileDescriptorGenerated = []byte{ + // 482 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xd4, 0x93, 0x4f, 0x6b, 0x13, 0x41, + 0x18, 0xc6, 0x77, 0x92, 0x6e, 0x9b, 0x4c, 0x13, 0x0d, 0x43, 0x0f, 0x43, 0x94, 0xcd, 0xd8, 0xd3, + 0x22, 0x74, 0x17, 0xf4, 0x20, 0x08, 0x1e, 0x5a, 0x1b, 0x24, 0x90, 0x7f, 0x4c, 0x52, 0x44, 0x2f, + 0x61, 0xb6, 0x3b, 0xd9, 0x0e, 0x4d, 0x76, 0x96, 0x99, 0x29, 0x52, 0xef, 0x42, 0xcd, 0x27, 0xf0, + 0x12, 0xa8, 0xda, 0x83, 0x1f, 0xc3, 0xa3, 0x17, 0xc1, 0x4f, 0x20, 0x92, 0x7e, 0x01, 0xcf, 0x9e, + 0x64, 0x27, 0x44, 0xe2, 0xd1, 0x9b, 0x3d, 0xcd, 0xf3, 0xfc, 0x66, 0x9e, 0x77, 0xde, 0x97, 0x61, + 0xe0, 0x93, 0x44, 0xaa, 0x44, 0xa6, 0x82, 0x05, 0x52, 0x25, 0xa1, 0xe1, 0xa9, 0x96, 0x2a, 0x14, + 0xa9, 0xe1, 0x2a, 0x65, 0x93, 0x50, 0x73, 0x25, 0xd8, 0x44, 0xbc, 0x66, 0x46, 0xc8, 0x34, 0xcc, + 0xa2, 0x30, 0xe1, 0x29, 0x57, 0xcc, 0xf0, 0x38, 0xc8, 0x94, 0x34, 0x12, 0xed, 0xad, 0xc7, 0x83, + 0x65, 0x3c, 0x58, 0xc5, 0x83, 0xbf, 0xe2, 0x41, 0x16, 0xd5, 0xf7, 0x12, 0x61, 0x4e, 0xce, 0xa2, + 0xe0, 0x58, 0x4e, 0xc3, 0x44, 0x26, 0x32, 0xb4, 0x55, 0xa2, 0xb3, 0xb1, 0x75, 0xd6, 0x58, 0xb5, + 0xac, 0xbe, 0xfb, 0x01, 0xc0, 0xc2, 0x7e, 0x1f, 0xed, 0x40, 0x57, 0x9f, 0xb0, 0x8c, 0x63, 0x40, + 0x8a, 0xbe, 0x4b, 0x97, 0x06, 0x61, 0xb8, 0xa5, 0x8d, 0x12, 0x31, 0xd7, 0xb8, 0x60, 0xf9, 0xca, + 0xa2, 0x3b, 0x10, 0x48, 0x5c, 0x24, 0xc0, 0xaf, 0x1e, 0x54, 0x7f, 0x7d, 0x6f, 0x94, 0x0f, 0x99, + 0x61, 0x3d, 0x15, 0x73, 0x45, 0x81, 0x44, 0x4d, 0x08, 0x0c, 0xde, 0x20, 0xc0, 0xbf, 0xf5, 0xe0, + 0x51, 0xf0, 0x4f, 0xdd, 0x07, 0x43, 0x25, 0x58, 0x9a, 0x4c, 0x38, 0x05, 0xe6, 0x71, 0xe9, 0xe2, + 0xb2, 0xe1, 0xfc, 0x7c, 0xdf, 0x70, 0x76, 0xbf, 0x02, 0xe8, 0x1e, 0xf2, 0x54, 0xf3, 0xff, 0xb1, + 0x4f, 0x84, 0xe0, 0x86, 0x39, 0xcf, 0x38, 0x76, 0x09, 0xf0, 0xcb, 0xd4, 0xea, 0x9c, 0xc5, 0xcc, + 0x30, 0xbc, 0x49, 0x80, 0x5f, 0xa1, 0x56, 0xaf, 0xcd, 0xf3, 0xb6, 0x00, 0xb7, 0x3b, 0x4c, 0x9f, + 0xf2, 0xf8, 0xc6, 0x4f, 0x95, 0xb3, 0x29, 0xd3, 0xa7, 0x78, 0x8b, 0x14, 0xfd, 0x12, 0xb5, 0x1a, + 0x11, 0x58, 0xc9, 0xd7, 0x91, 0xd0, 0x23, 0x2d, 0xc7, 0x06, 0x97, 0xec, 0x1e, 0xcc, 0x59, 0x4b, + 0x0f, 0xe4, 0x78, 0xed, 0x6d, 0xef, 0xbf, 0x01, 0xb0, 0xb4, 0xba, 0x17, 0xdd, 0x83, 0x95, 0x6e, + 0x6f, 
0x38, 0x1a, 0xd2, 0xd6, 0x7e, 0xf7, 0x59, 0xbb, 0x59, 0x73, 0xea, 0xb7, 0x67, 0x73, 0xb2, + 0xdd, 0x95, 0xe6, 0xcf, 0x91, 0x1d, 0xe8, 0x1e, 0xf5, 0xfb, 0x4d, 0x5a, 0x03, 0xf5, 0xf2, 0x6c, + 0x4e, 0xdc, 0xa3, 0x2c, 0xe3, 0x2a, 0xa7, 0xed, 0xde, 0xf3, 0x26, 0xad, 0x15, 0x96, 0xb4, 0x2d, + 0x5f, 0x71, 0x85, 0xee, 0xc2, 0xf2, 0xe0, 0x45, 0xa7, 0xd3, 0x1c, 0xd2, 0xd6, 0xd3, 0x5a, 0xb1, + 0x5e, 0x9d, 0xcd, 0x49, 0x79, 0x70, 0x3e, 0x9d, 0x72, 0xa3, 0xc4, 0x71, 0xbd, 0x72, 0xf1, 0xd1, + 0x73, 0x3e, 0x5d, 0x79, 0xce, 0xe7, 0x2b, 0xcf, 0x39, 0xc0, 0x5f, 0x16, 0x1e, 0xf8, 0xb6, 0xf0, + 0xc0, 0x8f, 0x85, 0xe7, 0xbc, 0xbb, 0xf6, 0x9c, 0xcb, 0x6b, 0x0f, 0xbc, 0x2c, 0x64, 0x51, 0xb4, + 0x69, 0x7f, 0xca, 0xc3, 0xdf, 0x01, 0x00, 0x00, 0xff, 0xff, 0x0f, 0xff, 0xbb, 0x8f, 0xc8, 0x03, + 0x00, 0x00, +} diff --git a/internal/storage/consts.go b/internal/storage/consts.go new file mode 100644 index 0000000..b6e03cc --- /dev/null +++ b/internal/storage/consts.go @@ -0,0 +1,29 @@ +package storage + +import ( + "reflect" + "unsafe" +) + +// Code generated by genlib2. DO NOT EDIT. + +var ( + bType = reflect.TypeOf(bool(false)) + iType = reflect.TypeOf(int(0)) + i8Type = reflect.TypeOf(int8(0)) + i16Type = reflect.TypeOf(int16(0)) + i32Type = reflect.TypeOf(int32(0)) + i64Type = reflect.TypeOf(int64(0)) + uType = reflect.TypeOf(uint(0)) + u8Type = reflect.TypeOf(uint8(0)) + u16Type = reflect.TypeOf(uint16(0)) + u32Type = reflect.TypeOf(uint32(0)) + u64Type = reflect.TypeOf(uint64(0)) + uintptrType = reflect.TypeOf(uintptr(0)) + f32Type = reflect.TypeOf(float32(0)) + f64Type = reflect.TypeOf(float64(0)) + c64Type = reflect.TypeOf(complex64(0)) + c128Type = reflect.TypeOf(complex128(0)) + strType = reflect.TypeOf(string("")) + unsafePointerType = reflect.TypeOf(unsafe.Pointer(nil)) +) diff --git a/internal/storage/getset.go b/internal/storage/getset.go index 879a5e3..89421f0 100644 --- a/internal/storage/getset.go +++ b/internal/storage/getset.go @@ -1,113 +1,149 @@ -// Code generated by genlib2. DO NOT EDIT. - package storage import "unsafe" +// Code generated by genlib2. DO NOT EDIT. 
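// Each accessor below reinterprets the Header's raw byte storage as a typed
// slice: the []byte is cast to the target slice type via unsafe, then
// re-sliced with a full slice expression so that both len and cap equal
// len(Raw)/sizeof(T). A hand-written sketch of the same pattern for float64
// (illustrative only, not generated output):
//
//	func float64sOf(raw []byte) []float64 {
//		n := len(raw) / 8 // 8 bytes per float64
//		return (*(*[]float64)(unsafe.Pointer(&raw)))[:n:n]
//	}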
+ /* bool */ -func (h *Header) Bools() []bool { return *(*[]bool)(unsafe.Pointer(h)) } +func (h *Header) Bools() []bool { + return (*(*[]bool)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(bType):h.TypedLen(bType)] +} func (h *Header) SetB(i int, x bool) { h.Bools()[i] = x } func (h *Header) GetB(i int) bool { return h.Bools()[i] } /* int */ -func (h *Header) Ints() []int { return *(*[]int)(unsafe.Pointer(h)) } +func (h *Header) Ints() []int { + return (*(*[]int)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(iType):h.TypedLen(iType)] +} func (h *Header) SetI(i int, x int) { h.Ints()[i] = x } func (h *Header) GetI(i int) int { return h.Ints()[i] } /* int8 */ -func (h *Header) Int8s() []int8 { return *(*[]int8)(unsafe.Pointer(h)) } +func (h *Header) Int8s() []int8 { + return (*(*[]int8)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(i8Type):h.TypedLen(i8Type)] +} func (h *Header) SetI8(i int, x int8) { h.Int8s()[i] = x } func (h *Header) GetI8(i int) int8 { return h.Int8s()[i] } /* int16 */ -func (h *Header) Int16s() []int16 { return *(*[]int16)(unsafe.Pointer(h)) } +func (h *Header) Int16s() []int16 { + return (*(*[]int16)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(i16Type):h.TypedLen(i16Type)] +} func (h *Header) SetI16(i int, x int16) { h.Int16s()[i] = x } func (h *Header) GetI16(i int) int16 { return h.Int16s()[i] } /* int32 */ -func (h *Header) Int32s() []int32 { return *(*[]int32)(unsafe.Pointer(h)) } +func (h *Header) Int32s() []int32 { + return (*(*[]int32)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(i32Type):h.TypedLen(i32Type)] +} func (h *Header) SetI32(i int, x int32) { h.Int32s()[i] = x } func (h *Header) GetI32(i int) int32 { return h.Int32s()[i] } /* int64 */ -func (h *Header) Int64s() []int64 { return *(*[]int64)(unsafe.Pointer(h)) } +func (h *Header) Int64s() []int64 { + return (*(*[]int64)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(i64Type):h.TypedLen(i64Type)] +} func (h *Header) SetI64(i int, x int64) { h.Int64s()[i] = x } func (h *Header) GetI64(i int) int64 { return h.Int64s()[i] } /* uint */ -func (h *Header) Uints() []uint { return *(*[]uint)(unsafe.Pointer(h)) } +func (h *Header) Uints() []uint { + return (*(*[]uint)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(uType):h.TypedLen(uType)] +} func (h *Header) SetU(i int, x uint) { h.Uints()[i] = x } func (h *Header) GetU(i int) uint { return h.Uints()[i] } /* uint8 */ -func (h *Header) Uint8s() []uint8 { return *(*[]uint8)(unsafe.Pointer(h)) } +func (h *Header) Uint8s() []uint8 { + return (*(*[]uint8)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(u8Type):h.TypedLen(u8Type)] +} func (h *Header) SetU8(i int, x uint8) { h.Uint8s()[i] = x } func (h *Header) GetU8(i int) uint8 { return h.Uint8s()[i] } /* uint16 */ -func (h *Header) Uint16s() []uint16 { return *(*[]uint16)(unsafe.Pointer(h)) } +func (h *Header) Uint16s() []uint16 { + return (*(*[]uint16)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(u16Type):h.TypedLen(u16Type)] +} func (h *Header) SetU16(i int, x uint16) { h.Uint16s()[i] = x } func (h *Header) GetU16(i int) uint16 { return h.Uint16s()[i] } /* uint32 */ -func (h *Header) Uint32s() []uint32 { return *(*[]uint32)(unsafe.Pointer(h)) } +func (h *Header) Uint32s() []uint32 { + return (*(*[]uint32)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(u32Type):h.TypedLen(u32Type)] +} func (h *Header) SetU32(i int, x uint32) { h.Uint32s()[i] = x } func (h *Header) GetU32(i int) uint32 { return h.Uint32s()[i] } /* uint64 */ -func (h *Header) Uint64s() []uint64 { return *(*[]uint64)(unsafe.Pointer(h)) } +func (h *Header) Uint64s() []uint64 { + return 
(*(*[]uint64)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(u64Type):h.TypedLen(u64Type)] +} func (h *Header) SetU64(i int, x uint64) { h.Uint64s()[i] = x } func (h *Header) GetU64(i int) uint64 { return h.Uint64s()[i] } /* uintptr */ -func (h *Header) Uintptrs() []uintptr { return *(*[]uintptr)(unsafe.Pointer(h)) } +func (h *Header) Uintptrs() []uintptr { + return (*(*[]uintptr)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(uintptrType):h.TypedLen(uintptrType)] +} func (h *Header) SetUintptr(i int, x uintptr) { h.Uintptrs()[i] = x } func (h *Header) GetUintptr(i int) uintptr { return h.Uintptrs()[i] } /* float32 */ -func (h *Header) Float32s() []float32 { return *(*[]float32)(unsafe.Pointer(h)) } +func (h *Header) Float32s() []float32 { + return (*(*[]float32)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(f32Type):h.TypedLen(f32Type)] +} func (h *Header) SetF32(i int, x float32) { h.Float32s()[i] = x } func (h *Header) GetF32(i int) float32 { return h.Float32s()[i] } /* float64 */ -func (h *Header) Float64s() []float64 { return *(*[]float64)(unsafe.Pointer(h)) } +func (h *Header) Float64s() []float64 { + return (*(*[]float64)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(f64Type):h.TypedLen(f64Type)] +} func (h *Header) SetF64(i int, x float64) { h.Float64s()[i] = x } func (h *Header) GetF64(i int) float64 { return h.Float64s()[i] } /* complex64 */ -func (h *Header) Complex64s() []complex64 { return *(*[]complex64)(unsafe.Pointer(h)) } +func (h *Header) Complex64s() []complex64 { + return (*(*[]complex64)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(c64Type):h.TypedLen(c64Type)] +} func (h *Header) SetC64(i int, x complex64) { h.Complex64s()[i] = x } func (h *Header) GetC64(i int) complex64 { return h.Complex64s()[i] } /* complex128 */ -func (h *Header) Complex128s() []complex128 { return *(*[]complex128)(unsafe.Pointer(h)) } +func (h *Header) Complex128s() []complex128 { + return (*(*[]complex128)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(c128Type):h.TypedLen(c128Type)] +} func (h *Header) SetC128(i int, x complex128) { h.Complex128s()[i] = x } func (h *Header) GetC128(i int) complex128 { return h.Complex128s()[i] } /* string */ -func (h *Header) Strings() []string { return *(*[]string)(unsafe.Pointer(h)) } +func (h *Header) Strings() []string { + return (*(*[]string)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(strType):h.TypedLen(strType)] +} func (h *Header) SetStr(i int, x string) { h.Strings()[i] = x } func (h *Header) GetStr(i int) string { return h.Strings()[i] } /* unsafe.Pointer */ -func (h *Header) UnsafePointers() []unsafe.Pointer { return *(*[]unsafe.Pointer)(unsafe.Pointer(h)) } +func (h *Header) UnsafePointers() []unsafe.Pointer { + return (*(*[]unsafe.Pointer)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(unsafePointerType):h.TypedLen(unsafePointerType)] +} func (h *Header) SetUnsafePointer(i int, x unsafe.Pointer) { h.UnsafePointers()[i] = x } func (h *Header) GetUnsafePointer(i int) unsafe.Pointer { return h.UnsafePointers()[i] } diff --git a/internal/storage/header.go b/internal/storage/header.go index 23eeb22..65f1069 100644 --- a/internal/storage/header.go +++ b/internal/storage/header.go @@ -3,49 +3,30 @@ package storage // import "gorgonia.org/tensor/internal/storage" import ( "reflect" "unsafe" + + _ "go4.org/unsafe/assume-no-moving-gc" ) // Header is runtime representation of a slice. It's a cleaner version of reflect.SliceHeader. // With this, we wouldn't need to keep the uintptr. 
// This usually means additional pressure for the GC though, especially when passing around Headers type Header struct { - Ptr unsafe.Pointer - L int - C int + Raw []byte } -func (h *Header) Pointer() unsafe.Pointer { return h.Ptr } -func (h *Header) Len() int { return h.L } +// TypedLen returns the length of data as if it was a slice of type t +func (h *Header) TypedLen(t reflect.Type) int { + return len(h.Raw) / int(t.Size()) +} func Copy(t reflect.Type, dst, src *Header) int { - if dst.L == 0 || src.L == 0 { - return 0 - } - - n := src.L - if dst.L < n { - n = dst.L - } - - // handle struct{} type - if t.Size() == 0 { - return n - } - - // memmove(dst.Pointer(), src.Pointer(), t.Size()) - // return n - - // otherwise, just copy bytes. - // FUTURE: implement memmove - dstBA := AsByteSlice(dst, t) - srcBA := AsByteSlice(src, t) - copied := copy(dstBA, srcBA) + copied := copy(dst.Raw, src.Raw) return copied / int(t.Size()) } func CopySliced(t reflect.Type, dst *Header, dstart, dend int, src *Header, sstart, send int) int { - dstBA := AsByteSlice(dst, t) - srcBA := AsByteSlice(src, t) + dstBA := dst.Raw + srcBA := src.Raw size := int(t.Size()) ds := dstart * size @@ -56,9 +37,32 @@ func CopySliced(t reflect.Type, dst *Header, dstart, dend int, src *Header, ssta return copied / size } +func SwapCopy(a, b *Header) { + for i := range a.Raw { + a.Raw[i], b.Raw[i] = b.Raw[i], a.Raw[i] + } +} + +func Fill(t reflect.Type, dst, src *Header) int { + dstBA := dst.Raw + srcBA := src.Raw + size := int(t.Size()) + lenSrc := len(srcBA) + + dstart := 0 + for { + copied := copy(dstBA[dstart:], srcBA) + dstart += copied + if copied < lenSrc { + break + } + } + return dstart / size +} + func CopyIter(t reflect.Type, dst, src *Header, diter, siter Iterator) int { - dstBA := AsByteSlice(dst, t) - srcBA := AsByteSlice(src, t) + dstBA := dst.Raw + srcBA := src.Raw size := int(t.Size()) var idx, jdx, i, j, count int @@ -85,17 +89,30 @@ func CopyIter(t reflect.Type, dst, src *Header, diter, siter Iterator) int { return count } -func AsByteSlice(a *Header, t reflect.Type) []byte { - size := a.L * int(t.Size()) +// Element gets the pointer of ith element +func ElementAt(i int, base unsafe.Pointer, typeSize uintptr) unsafe.Pointer { + return unsafe.Pointer(uintptr(base) + uintptr(i)*typeSize) +} + +// AsByteSlice takes a slice of anything and returns a casted-as-byte-slice view of it. +// This function panics if input is not a slice. 
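// The returned bytes alias the input slice's backing array, so writes through
// either view are visible in the other. Illustrative use (sketch only):
//
//	xs := []float32{1, 2, 3}
//	bs := AsByteSlice(xs) // len(bs) == 12; shares memory with xs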
+func AsByteSlice(x interface{}) []byte { + xV := reflect.ValueOf(x) + xT := reflect.TypeOf(x).Elem() // expects a []T + hdr := reflect.SliceHeader{ - Data: uintptr(a.Ptr), - Len: size, - Cap: size, + Data: xV.Pointer(), + Len: xV.Len() * int(xT.Size()), + Cap: xV.Cap() * int(xT.Size()), } return *(*[]byte)(unsafe.Pointer(&hdr)) } -// Element gets the pointer of ith element -func ElementAt(i int, base unsafe.Pointer, typeSize uintptr) unsafe.Pointer { - return unsafe.Pointer(uintptr(base) + uintptr(i)*typeSize) +func FromMemory(ptr uintptr, memsize uintptr) []byte { + hdr := reflect.SliceHeader{ + Data: ptr, + Len: int(memsize), + Cap: int(memsize), + } + return *(*[]byte)(unsafe.Pointer(&hdr)) } diff --git a/internal/storage/header_test.go b/internal/storage/header_test.go new file mode 100644 index 0000000..ab28ac5 --- /dev/null +++ b/internal/storage/header_test.go @@ -0,0 +1,70 @@ +package storage + +import ( + "github.com/stretchr/testify/assert" + "reflect" + "testing" +) + +func TestCopy(t *testing.T) { + // A longer than B + a := headerFromSlice([]int{0, 1, 2, 3, 4}) + b := headerFromSlice([]int{10, 11}) + copied := Copy(reflect.TypeOf(1), &a, &b) + + assert.Equal(t, 2, copied) + assert.Equal(t, []int{10, 11, 2, 3, 4}, a.Ints()) + + // B longer than A + a = headerFromSlice([]int{10, 11}) + b = headerFromSlice([]int{0, 1, 2, 3, 4}) + copied = Copy(reflect.TypeOf(1), &a, &b) + + assert.Equal(t, 2, copied) + assert.Equal(t, []int{0, 1}, a.Ints()) + + // A is empty + a = headerFromSlice([]int{}) + b = headerFromSlice([]int{0, 1, 2, 3, 4}) + copied = Copy(reflect.TypeOf(1), &a, &b) + + assert.Equal(t, 0, copied) + + // B is empty + a = headerFromSlice([]int{0, 1, 2, 3, 4}) + b = headerFromSlice([]int{}) + copied = Copy(reflect.TypeOf(1), &a, &b) + + assert.Equal(t, 0, copied) + assert.Equal(t, []int{0, 1, 2, 3, 4}, a.Ints()) +} + +func TestFill(t *testing.T) { + // A longer than B + a := headerFromSlice([]int{0, 1, 2, 3, 4}) + b := headerFromSlice([]int{10, 11}) + copied := Fill(reflect.TypeOf(1), &a, &b) + + assert.Equal(t, 5, copied) + assert.Equal(t, []int{10, 11, 10, 11, 10}, a.Ints()) + + // B longer than A + a = headerFromSlice([]int{10, 11}) + b = headerFromSlice([]int{0, 1, 2, 3, 4}) + copied = Fill(reflect.TypeOf(1), &a, &b) + + assert.Equal(t, 2, copied) + assert.Equal(t, []int{0, 1}, a.Ints()) +} + +func headerFromSlice(x interface{}) Header { + xT := reflect.TypeOf(x) + if xT.Kind() != reflect.Slice { + panic("Expected a slice") + } + xV := reflect.ValueOf(x) + size := uintptr(xV.Len()) * xT.Elem().Size() + return Header{ + Raw: FromMemory(xV.Pointer(), size), + } +} diff --git a/internal/storage/keepsync.go b/internal/storage/keepsync.go index f008e2a..dde26cd 100644 --- a/internal/storage/keepsync.go +++ b/internal/storage/keepsync.go @@ -19,18 +19,12 @@ type NoOpError interface { NoOp() bool } -type noopError struct{} - -func (e noopError) NoOp() bool { return true } -func (e noopError) Error() string { return "NoOp" } - func handleNoOp(err error) error { if err == nil { return nil } - - if _, ok := err.(NoOpError); !ok { - return err + if _, ok := err.(NoOpError); ok { + return nil } - return nil + return err } diff --git a/internal/storage/keepsync_test.go b/internal/storage/keepsync_test.go new file mode 100644 index 0000000..00b2182 --- /dev/null +++ b/internal/storage/keepsync_test.go @@ -0,0 +1,20 @@ +package storage + +import ( + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "testing" +) + +type noopError struct{} + +func (e noopError) NoOp() 
bool { return true } +func (e noopError) Error() string { return "NoOp" } + +func TestHandleNoOp(t *testing.T) { + otherErr := errors.New("other error") + + assert.Equal(t, nil, handleNoOp(noopError{})) + assert.Equal(t, nil, handleNoOp(nil)) + assert.Equal(t, otherErr, handleNoOp(otherErr)) +} diff --git a/iterator.go b/iterator.go index 0801748..0fe0a5d 100644 --- a/iterator.go +++ b/iterator.go @@ -1,6 +1,8 @@ package tensor -import "runtime" +import ( + "runtime" +) func requiresOrderedIterator(e Engine, t Tensor) bool { if t.IsScalar() { @@ -70,7 +72,7 @@ func NewIterator(aps ...*AP) Iterator { case 0: return nil case 1: - return NewFlatIterator(aps[0]) + return newFlatIterator(aps[0]) default: return NewMultIterator(aps...) } @@ -111,43 +113,54 @@ func iteratorLoadAP(it Iterator, ap *AP) { /* FLAT ITERATOR */ -// FlatIterator is an iterator that iterates over Tensors. It utilizes the *AP -// of a Tensor to determine what the next index is. +// FlatIterator is an iterator that iterates over Tensors according to the data's layout. +// It utilizes the *AP of a Tensor to determine what the next index is. // This data structure is similar to Numpy's flatiter, with some standard Go based restrictions of course // (such as, not allowing negative indices) type FlatIterator struct { *AP //state - nextIndex int - lastIndex int - strides0 int - size int - track []int - done bool - reverse bool // if true, iterator starts at end of array and runs backwards -} - -// NewFlatIterator creates a new FlatIterator. -func NewFlatIterator(ap *AP) *FlatIterator { - var strides0 int - if ap.IsVector() { - strides0 = ap.strides[0] - } else if ap.o.isColMajor() { - strides0 = ap.strides[0] + track []int + nextIndex int + lastIndex int + size int + done bool + veclikeDim int // the dimension of a vectorlike shape that is not a 1. + reverse bool // if true, iterator starts at end of array and runs backwards + + isScalar bool + isVector bool + + outerFirst bool +} + +// newFlatIterator creates a new FlatIterator. 
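// A typical use is to pump the iterator until it reports a NoOp error, e.g.
// (sketch only; ap is assumed to be a valid, initialized *AP):
//
//	it := newFlatIterator(ap)
//	for i, err := it.Next(); err == nil; i, err = it.Next() {
//		_ = i // i is the next index into the tensor's underlying data
//	}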
+func newFlatIterator(ap *AP) *FlatIterator { + var dim int + if ap.IsVectorLike() { + for d, i := range ap.shape { + if i != 1 { + dim = d + break + } + } } return &FlatIterator{ - AP: ap, - track: make([]int, len(ap.shape)), - size: ap.shape.TotalSize(), - strides0: strides0, + AP: ap, + track: make([]int, len(ap.shape)), + size: ap.shape.TotalSize(), + veclikeDim: dim, + + isScalar: ap.IsScalar(), + isVector: ap.IsVectorLike(), } } // FlatIteratorFromDense creates a new FlatIterator from a dense tensor func FlatIteratorFromDense(tt DenseTensor) *FlatIterator { - return NewFlatIterator(tt.Info()) + return newFlatIterator(tt.Info()) } // SetReverse initializes iterator to run backwards @@ -182,10 +195,10 @@ func (it *FlatIterator) Next() (int, error) { } switch { - case it.IsScalar(): + case it.isScalar: it.done = true return 0, nil - case it.IsVector(): + case it.isVector: if it.reverse { return it.singlePrevious() } @@ -194,6 +207,9 @@ func (it *FlatIterator) Next() (int, error) { if it.reverse { return it.ndPrevious() } + if it.outerFirst { + return it.colMajorNDNext() + } return it.ndNext() } } @@ -212,10 +228,10 @@ func (it *FlatIterator) NextValid() (int, int, error) { return -1, 1, noopError{} } switch { - case it.IsScalar(): + case it.isScalar: it.done = true return 0, 0, nil - case it.IsVector(): + case it.isVector: if it.reverse { a, err := it.singlePrevious() return a, -1, err @@ -227,6 +243,11 @@ func (it *FlatIterator) NextValid() (int, int, error) { a, err := it.ndPrevious() return a, -1, err } + + if it.outerFirst { + a, err := it.colMajorNDNext() + return a, 1, err + } a, err := it.ndNext() return a, 1, err } @@ -245,20 +266,11 @@ func (it *FlatIterator) NextInvalid() (int, int, error) { func (it *FlatIterator) singleNext() (int, error) { it.lastIndex = it.nextIndex - // it.lastIndex += it.strides[0] - it.nextIndex += it.strides0 + it.nextIndex++ var tracked int - switch { - case it.IsRowVec(): - it.track[1]++ - tracked = it.track[1] - case it.IsColVec(), it.IsVector(): - it.track[0]++ - tracked = it.track[0] - default: - panic("This ain't supposed to happen") - } + it.track[it.veclikeDim]++ + tracked = it.track[it.veclikeDim] if tracked >= it.size { it.done = true @@ -269,48 +281,86 @@ func (it *FlatIterator) singleNext() (int, error) { func (it *FlatIterator) singlePrevious() (int, error) { it.lastIndex = it.nextIndex - // it.lastIndex += it.strides[0] - it.nextIndex -= it.strides0 + it.nextIndex-- var tracked int - switch { - case it.IsRowVec(): - it.track[1]-- - tracked = it.track[1] - case it.IsColVec(), it.IsVector(): - it.track[0]-- - tracked = it.track[0] - default: - panic("This ain't supposed to happen") - } + it.track[it.veclikeDim]-- + tracked = it.track[it.veclikeDim] if tracked < 0 { it.done = true } - return it.lastIndex, nil } func (it *FlatIterator) ndNext() (int, error) { - it.lastIndex = it.nextIndex - for i := len(it.shape) - 1; i >= 0; i-- { - it.track[i]++ - if it.track[i] == it.shape[i] { + // the reason for this weird looking bits of code is because the SSA compiler doesn't + // know how to optimize for this bit of code, not keeping things in registers correctly + // @stuartcarnie optimized this iout to great effect + + v := len(it.shape) - 1 + nextIndex := it.nextIndex + it.lastIndex = nextIndex + + // the following 3 lines causes the compiler to perform bounds check here, + // instead of being done in the loop + coord := it.shape[:v+1] + track := it.track[:v+1] + strides := it.strides[:v+1] + for i := v; i >= 0; i-- { + track[i]++ + shapeI := 
coord[i] + strideI := strides[i] + + if track[i] == shapeI { if i == 0 { it.done = true } - it.track[i] = 0 - it.nextIndex -= (it.shape[i] - 1) * it.strides[i] + track[i] = 0 + nextIndex -= (shapeI - 1) * strideI continue } - it.nextIndex += it.strides[i] + nextIndex += strideI break } + it.nextIndex = nextIndex return it.lastIndex, nil } func (it *FlatIterator) colMajorNDNext() (int, error) { - return 0, nil + // the reason for this weird looking bits of code is because the SSA compiler doesn't + // know how to optimize for this bit of code, not keeping things in registers correctly + // @stuartcarnie optimized this iout to great effect + + v := len(it.shape) - 1 + nextIndex := it.nextIndex + it.lastIndex = nextIndex + + // the following 3 lines causes the compiler to perform bounds check here, + // instead of being done in the loop + coord := it.shape[:v+1] + track := it.track[:v+1] + strides := it.strides[:v+1] + for i := 0; i <= v; i++ { + track[i]++ + shapeI := coord[i] + strideI := strides[i] + + if track[i] == shapeI { + if i == v { + it.done = true + } + track[i] = 0 + + nextIndex -= (shapeI - 1) * strideI + continue + } + nextIndex += strideI + break + } + it.nextIndex = nextIndex + return it.lastIndex, nil + } func (it *FlatIterator) ndPrevious() (int, error) { @@ -331,6 +381,7 @@ func (it *FlatIterator) ndPrevious() (int, error) { return it.lastIndex, nil } +// TODO v0.9.0 func (it *FlatIterator) colMajorNDPrevious() (int, error) { return 0, nil } @@ -338,9 +389,10 @@ func (it *FlatIterator) colMajorNDPrevious() (int, error) { // Coord returns the next coordinate. // When Next() is called, the coordinates are updated AFTER the Next() returned. // See example for more details. -func (it *FlatIterator) Coord() []int { - return it.track -} +// +// The returned coordinates is mutable. Changing any values in the return value will +// change the state of the iterator +func (it *FlatIterator) Coord() []int { return it.track } // Slice is a convenience function that augments func (it *FlatIterator) Slice(sli Slice) (retVal []int, err error) { @@ -401,10 +453,12 @@ func (it *FlatIterator) Reset() { switch { case it.IsScalar(): it.nextIndex = 0 - case it.IsRowVec(): - it.nextIndex = (it.shape[1] - 1) * it.strides[0] - case it.IsColVec(), it.IsVector(): + case it.isVector: it.nextIndex = (it.shape[0] - 1) * it.strides[0] + // case it.IsRowVec(): + // it.nextIndex = (it.shape[1] - 1) * it.strides[1] + // case it.IsColVec(): + // it.nextIndex = (it.shape[0] - 1) * it.strides[0] default: it.nextIndex = 0 for i := range it.track { diff --git a/iterator_axial.go b/iterator_axial.go new file mode 100644 index 0000000..493ac21 --- /dev/null +++ b/iterator_axial.go @@ -0,0 +1,170 @@ +package tensor + +// AxialIterator iterates based on a given axis +type AxialIterator struct { + *AP + axis int // the axis to iterate along + + // state + axisSz int // if an axis is of size N, then axisSz indicates the current num (0 - N). + nextIndex int + lastIndex int + track []int + isReverse bool + done bool + fixed bool +} + +// AxialIteratorFromDense creates and axial iterator that will iterate along the given axis. `fixedAxis` defines if the axisSz is fixed. 
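// Illustrative use (sketch only; t is assumed to be a 3-dimensional
// DenseTensor): with fixedAxis set, only elements whose coordinate along the
// chosen axis equals axisSz are visited.
//
//	it := AxialIteratorFromDense(t, 1, 0, true) // axis 1 frozen at 0
//	for i, err := it.Start(); err == nil; i, err = it.Next() {
//		_ = i // flat index into t's data
//	}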
+func AxialIteratorFromDense(t DenseTensor, axis, axisSz int, fixedAxis bool) *AxialIterator { + ap := t.Info() + return &AxialIterator{ + AP: ap, + track: make([]int, len(ap.shape)), + axis: axis, + axisSz: axisSz, + fixed: fixedAxis, + } +} + +// Start returns the first index +func (it *AxialIterator) Start() (retVal int, err error) { + it.Reset() + + // compute the nextIndex + if it.fixed { + it.track[it.axis] = it.axisSz + it.nextIndex, err = Ltoi(it.shape, it.strides, it.track...) + } + + return it.Next() +} + +// Next returns the next index. +// Example: let's say we're iterating on a tensor with the following +// shape: (2, 3, 4); axis: 1 +// At the start, the coordinates are: +// coordinates: (0, 0, 0) +// Next() will yield: +// coordinates: (0, 0, 1) +// But when the coordinates are: +// coordinates: (0, 0, 4) +// Next() will yield: +// coordinates: (1, 0, 0). +// Note that axis 1 is frozen at 0. +func (it *AxialIterator) Next() (int, error) { + if it.done { + return -1, noopError{} + } + + switch { + case it.isReverse: + return it.ndPrevious() + default: + return it.ndNext() + } + +} + +func (it *AxialIterator) ndNext() (int, error) { + v := len(it.shape) - 1 + nextIndex := it.nextIndex + it.lastIndex = nextIndex + + track := it.track[:v+1] // force bounds check + coord := it.shape[:v+1] // force bounds check + strides := it.strides[:v+1] // fource bounds check + sz := it.axisSz + track[it.axis] = sz + + for i := v; i >= 0; i-- { + if i == it.axis { + if i == 0 { + if it.fixed || track[it.axis] == coord[it.axis] || it.axisSz >= coord[it.axis] { + track[it.axis] = 0 + it.done = true + break + } + it.axisSz++ + track[it.axis] = it.axisSz + } + continue // we're iterating along an axis. + } + track[i]++ + shapeI := coord[i] + strideI := strides[i] + if track[i] == shapeI { + track[i] = 0 + nextIndex -= (shapeI - 1) * strideI + if i == 0 { + it.axisSz++ + track[it.axis] = it.axisSz + + if it.fixed || track[it.axis] == coord[it.axis] || it.axisSz >= coord[it.axis] { + track[it.axis] = 0 + it.done = true + break + } + + nextIndex = track[it.axis] * strides[it.axis] + } + + continue + } + nextIndex += strideI + break + } + it.nextIndex = nextIndex + return it.lastIndex, nil +} + +func (it *AxialIterator) ndPrevious() (int, error) { + panic("Not yet implemented") +} + +// NextValidity is like Next, but returns the validity of the value at the index as well. +func (it *AxialIterator) NextValidity() (int, bool, error) { + i, err := it.Next() + return i, true, err +} + +// NextValid returns the next valid index, as well as a skip count. +func (it *AxialIterator) NextValid() (int, int, error) { + if it.done { + return -1, 1, noopError{} + } + + switch { + case it.isReverse: + a, err := it.ndPrevious() + return a, -1, err + default: + a, err := it.ndNext() + return a, 1, err + } +} + +// NextInvalid returns the next invalid index, as well as a skip count. +func (it *AxialIterator) NextInvalid() (int, int, error) { + panic("not implemented") // TODO: Implement +} + +// Reset resets the iterator +func (it *AxialIterator) Reset() { + it.nextIndex = 0 + for i := range it.track { + it.track[i] = 0 + } +} + +// SetReverse tells the iterator to iterate in reverse +func (it *AxialIterator) SetReverse() { it.isReverse = true } + +// SetForward tells the iterator to iterate forwards +func (it *AxialIterator) SetForward() { it.isReverse = false } + +// Coord returns the coordinates +func (it *AxialIterator) Coord() []int { return it.track } + +// Done returns true when the iterator is done iterating. 
+func (it *AxialIterator) Done() bool { return it.done } diff --git a/iterator_mult.go b/iterator_mult.go index a0af458..74f9a4b 100644 --- a/iterator_mult.go +++ b/iterator_mult.go @@ -97,16 +97,19 @@ func NewMultIterator(aps ...*AP) *MultIterator { ReturnInts(apStrides) // Borrowed in BroadcastStrides but returned here - dangerous pattern? nBlocks++ } - ap2 := NewAP(it.shape[:maxDims], it.strides[offset:offset+maxDims]) - ap2.o = ap.o - ap2.Δ = ap.Δ - + ap2 := MakeAP(it.shape[:maxDims], it.strides[offset:offset+maxDims], ap.o, ap.Δ) it.whichBlock[i] = f - it.fitArr[nBlocks-1] = NewFlatIterator(ap2) + it.fitArr[nBlocks-1] = newFlatIterator(&ap2) } it.fitArr = it.fitArr[:nBlocks] it.strides = it.strides[:nBlocks*maxDims] + // fill 0s with 1s + for i := range it.strides { + if it.strides[i] == 0 { + it.strides[i] = 1 + } + } it.fit0 = it.fitArr[0] for _, f := range it.fitArr { @@ -120,7 +123,7 @@ func NewMultIterator(aps ...*AP) *MultIterator { // MultIteratorFromDense creates a new MultIterator from a list of dense tensors func MultIteratorFromDense(tts ...DenseTensor) *MultIterator { - aps := BorrowAPList(len(tts)) + aps := make([]*AP, len(tts)) hasMask := BorrowBools(len(tts)) defer ReturnBools(hasMask) @@ -155,7 +158,6 @@ func MultIteratorFromDense(tts ...DenseTensor) *MultIterator { } } it.numMasked = numMasked - ReturnAPList(aps) return it } @@ -221,7 +223,9 @@ func (it *MultIterator) Next() (int, error) { } it.done = false for _, f := range it.fitArr { - f.Next() + if _, err := f.Next(); err != nil { + return -1, err + } it.done = it.done || f.done } for i, j := range it.whichBlock { diff --git a/iterator_native.go b/iterator_native.go new file mode 100644 index 0000000..470891d --- /dev/null +++ b/iterator_native.go @@ -0,0 +1,1152 @@ +package tensor + +import ( + "reflect" + "unsafe" + + "github.com/pkg/errors" + "gorgonia.org/dtype" +) + +// Code generated by genlib2. DO NOT EDIT. + +func checkNativeIterable(t *Dense, dims int, dt dtype.Dtype) error { + // checks: + if !t.IsNativelyAccessible() { + return errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible") + } + + if t.Shape().Dims() != dims { + return errors.Errorf("Cannot convert *Dense to native iterator. Expected number of dimension: %d, T has got %d dimensions (Shape: %v)", dims, t.Dims(), t.Shape()) + } + + if t.F() || t.RequiresIterator() { + return errors.Errorf("Not yet implemented: native matrix for colmajor or unpacked matrices") + } + + if t.Dtype() != dt { + return errors.Errorf("Conversion to native iterable only works on %v. Got %v", dt, t.Dtype()) + } + + return nil +} + +/* Native Iterables for bool */ + +// nativeDenseVectorB converts a *Dense into a []bool +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorB(t *Dense) (retVal []bool, err error) { + if err = checkNativeIterable(t, 1, Bool); err != nil { + return nil, err + } + return t.Bools(), nil +} + +// nativeDenseMatrixB converts a *Dense into a [][]bool +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. 
+func nativeDenseMatrixB(t *Dense) (retVal [][]bool, err error) { + if err = checkNativeIterable(t, 2, Bool); err != nil { + return nil, err + } + + data := t.Bools() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]bool, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]bool, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3B converts a *Dense into a [][][]bool. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3B(t *Dense) (retVal [][][]bool, err error) { + if err = checkNativeIterable(t, 3, Bool); err != nil { + return nil, err + } + + data := t.Bools() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]bool, layers) + for i := range retVal { + retVal[i] = make([][]bool, rows) + for j := range retVal[i] { + retVal[i][j] = make([]bool, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int */ + +// nativeDenseVectorI converts a *Dense into a []int +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorI(t *Dense) (retVal []int, err error) { + if err = checkNativeIterable(t, 1, Int); err != nil { + return nil, err + } + return t.Ints(), nil +} + +// nativeDenseMatrixI converts a *Dense into a [][]int +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixI(t *Dense) (retVal [][]int, err error) { + if err = checkNativeIterable(t, 2, Int); err != nil { + return nil, err + } + + data := t.Ints() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3I converts a *Dense into a [][][]int. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3I(t *Dense) (retVal [][][]int, err error) { + if err = checkNativeIterable(t, 3, Int); err != nil { + return nil, err + } + + data := t.Ints() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int, layers) + for i := range retVal { + retVal[i] = make([][]int, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int8 */ + +// nativeDenseVectorI8 converts a *Dense into a []int8 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. 
+func nativeDenseVectorI8(t *Dense) (retVal []int8, err error) { + if err = checkNativeIterable(t, 1, Int8); err != nil { + return nil, err + } + return t.Int8s(), nil +} + +// nativeDenseMatrixI8 converts a *Dense into a [][]int8 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixI8(t *Dense) (retVal [][]int8, err error) { + if err = checkNativeIterable(t, 2, Int8); err != nil { + return nil, err + } + + data := t.Int8s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int8, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3I8 converts a *Dense into a [][][]int8. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3I8(t *Dense) (retVal [][][]int8, err error) { + if err = checkNativeIterable(t, 3, Int8); err != nil { + return nil, err + } + + data := t.Int8s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int8, layers) + for i := range retVal { + retVal[i] = make([][]int8, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int8, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int16 */ + +// nativeDenseVectorI16 converts a *Dense into a []int16 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorI16(t *Dense) (retVal []int16, err error) { + if err = checkNativeIterable(t, 1, Int16); err != nil { + return nil, err + } + return t.Int16s(), nil +} + +// nativeDenseMatrixI16 converts a *Dense into a [][]int16 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixI16(t *Dense) (retVal [][]int16, err error) { + if err = checkNativeIterable(t, 2, Int16); err != nil { + return nil, err + } + + data := t.Int16s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int16, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3I16 converts a *Dense into a [][][]int16. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func nativeDenseTensor3I16(t *Dense) (retVal [][][]int16, err error) { + if err = checkNativeIterable(t, 3, Int16); err != nil { + return nil, err + } + + data := t.Int16s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int16, layers) + for i := range retVal { + retVal[i] = make([][]int16, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int16, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int32 */ + +// nativeDenseVectorI32 converts a *Dense into a []int32 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorI32(t *Dense) (retVal []int32, err error) { + if err = checkNativeIterable(t, 1, Int32); err != nil { + return nil, err + } + return t.Int32s(), nil +} + +// nativeDenseMatrixI32 converts a *Dense into a [][]int32 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixI32(t *Dense) (retVal [][]int32, err error) { + if err = checkNativeIterable(t, 2, Int32); err != nil { + return nil, err + } + + data := t.Int32s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int32, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3I32 converts a *Dense into a [][][]int32. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3I32(t *Dense) (retVal [][][]int32, err error) { + if err = checkNativeIterable(t, 3, Int32); err != nil { + return nil, err + } + + data := t.Int32s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int32, layers) + for i := range retVal { + retVal[i] = make([][]int32, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int32, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int64 */ + +// nativeDenseVectorI64 converts a *Dense into a []int64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorI64(t *Dense) (retVal []int64, err error) { + if err = checkNativeIterable(t, 1, Int64); err != nil { + return nil, err + } + return t.Int64s(), nil +} + +// nativeDenseMatrixI64 converts a *Dense into a [][]int64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. 
+func nativeDenseMatrixI64(t *Dense) (retVal [][]int64, err error) { + if err = checkNativeIterable(t, 2, Int64); err != nil { + return nil, err + } + + data := t.Int64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int64, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3I64 converts a *Dense into a [][][]int64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3I64(t *Dense) (retVal [][][]int64, err error) { + if err = checkNativeIterable(t, 3, Int64); err != nil { + return nil, err + } + + data := t.Int64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int64, layers) + for i := range retVal { + retVal[i] = make([][]int64, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int64, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint */ + +// nativeDenseVectorU converts a *Dense into a []uint +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorU(t *Dense) (retVal []uint, err error) { + if err = checkNativeIterable(t, 1, Uint); err != nil { + return nil, err + } + return t.Uints(), nil +} + +// nativeDenseMatrixU converts a *Dense into a [][]uint +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixU(t *Dense) (retVal [][]uint, err error) { + if err = checkNativeIterable(t, 2, Uint); err != nil { + return nil, err + } + + data := t.Uints() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3U converts a *Dense into a [][][]uint. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3U(t *Dense) (retVal [][][]uint, err error) { + if err = checkNativeIterable(t, 3, Uint); err != nil { + return nil, err + } + + data := t.Uints() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint, layers) + for i := range retVal { + retVal[i] = make([][]uint, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint8 */ + +// nativeDenseVectorU8 converts a *Dense into a []uint8 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. 
+func nativeDenseVectorU8(t *Dense) (retVal []uint8, err error) { + if err = checkNativeIterable(t, 1, Uint8); err != nil { + return nil, err + } + return t.Uint8s(), nil +} + +// nativeDenseMatrixU8 converts a *Dense into a [][]uint8 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixU8(t *Dense) (retVal [][]uint8, err error) { + if err = checkNativeIterable(t, 2, Uint8); err != nil { + return nil, err + } + + data := t.Uint8s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint8, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3U8 converts a *Dense into a [][][]uint8. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3U8(t *Dense) (retVal [][][]uint8, err error) { + if err = checkNativeIterable(t, 3, Uint8); err != nil { + return nil, err + } + + data := t.Uint8s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint8, layers) + for i := range retVal { + retVal[i] = make([][]uint8, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint8, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint16 */ + +// nativeDenseVectorU16 converts a *Dense into a []uint16 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorU16(t *Dense) (retVal []uint16, err error) { + if err = checkNativeIterable(t, 1, Uint16); err != nil { + return nil, err + } + return t.Uint16s(), nil +} + +// nativeDenseMatrixU16 converts a *Dense into a [][]uint16 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixU16(t *Dense) (retVal [][]uint16, err error) { + if err = checkNativeIterable(t, 2, Uint16); err != nil { + return nil, err + } + + data := t.Uint16s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint16, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3U16 converts a *Dense into a [][][]uint16. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func nativeDenseTensor3U16(t *Dense) (retVal [][][]uint16, err error) { + if err = checkNativeIterable(t, 3, Uint16); err != nil { + return nil, err + } + + data := t.Uint16s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint16, layers) + for i := range retVal { + retVal[i] = make([][]uint16, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint16, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint32 */ + +// nativeDenseVectorU32 converts a *Dense into a []uint32 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorU32(t *Dense) (retVal []uint32, err error) { + if err = checkNativeIterable(t, 1, Uint32); err != nil { + return nil, err + } + return t.Uint32s(), nil +} + +// nativeDenseMatrixU32 converts a *Dense into a [][]uint32 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixU32(t *Dense) (retVal [][]uint32, err error) { + if err = checkNativeIterable(t, 2, Uint32); err != nil { + return nil, err + } + + data := t.Uint32s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint32, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3U32 converts a *Dense into a [][][]uint32. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3U32(t *Dense) (retVal [][][]uint32, err error) { + if err = checkNativeIterable(t, 3, Uint32); err != nil { + return nil, err + } + + data := t.Uint32s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint32, layers) + for i := range retVal { + retVal[i] = make([][]uint32, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint32, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint64 */ + +// nativeDenseVectorU64 converts a *Dense into a []uint64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorU64(t *Dense) (retVal []uint64, err error) { + if err = checkNativeIterable(t, 1, Uint64); err != nil { + return nil, err + } + return t.Uint64s(), nil +} + +// nativeDenseMatrixU64 converts a *Dense into a [][]uint64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. 
+func nativeDenseMatrixU64(t *Dense) (retVal [][]uint64, err error) { + if err = checkNativeIterable(t, 2, Uint64); err != nil { + return nil, err + } + + data := t.Uint64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint64, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3U64 converts a *Dense into a [][][]uint64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3U64(t *Dense) (retVal [][][]uint64, err error) { + if err = checkNativeIterable(t, 3, Uint64); err != nil { + return nil, err + } + + data := t.Uint64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint64, layers) + for i := range retVal { + retVal[i] = make([][]uint64, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint64, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for float32 */ + +// nativeDenseVectorF32 converts a *Dense into a []float32 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorF32(t *Dense) (retVal []float32, err error) { + if err = checkNativeIterable(t, 1, Float32); err != nil { + return nil, err + } + return t.Float32s(), nil +} + +// nativeDenseMatrixF32 converts a *Dense into a [][]float32 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixF32(t *Dense) (retVal [][]float32, err error) { + if err = checkNativeIterable(t, 2, Float32); err != nil { + return nil, err + } + + data := t.Float32s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]float32, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]float32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3F32 converts a *Dense into a [][][]float32. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func nativeDenseTensor3F32(t *Dense) (retVal [][][]float32, err error) { + if err = checkNativeIterable(t, 3, Float32); err != nil { + return nil, err + } + + data := t.Float32s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]float32, layers) + for i := range retVal { + retVal[i] = make([][]float32, rows) + for j := range retVal[i] { + retVal[i][j] = make([]float32, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for float64 */ + +// nativeDenseVectorF64 converts a *Dense into a []float64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorF64(t *Dense) (retVal []float64, err error) { + if err = checkNativeIterable(t, 1, Float64); err != nil { + return nil, err + } + return t.Float64s(), nil +} + +// nativeDenseMatrixF64 converts a *Dense into a [][]float64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixF64(t *Dense) (retVal [][]float64, err error) { + if err = checkNativeIterable(t, 2, Float64); err != nil { + return nil, err + } + + data := t.Float64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]float64, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]float64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3F64 converts a *Dense into a [][][]float64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3F64(t *Dense) (retVal [][][]float64, err error) { + if err = checkNativeIterable(t, 3, Float64); err != nil { + return nil, err + } + + data := t.Float64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]float64, layers) + for i := range retVal { + retVal[i] = make([][]float64, rows) + for j := range retVal[i] { + retVal[i][j] = make([]float64, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for complex64 */ + +// nativeDenseVectorC64 converts a *Dense into a []complex64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorC64(t *Dense) (retVal []complex64, err error) { + if err = checkNativeIterable(t, 1, Complex64); err != nil { + return nil, err + } + return t.Complex64s(), nil +} + +// nativeDenseMatrixC64 converts a *Dense into a [][]complex64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. 
+func nativeDenseMatrixC64(t *Dense) (retVal [][]complex64, err error) { + if err = checkNativeIterable(t, 2, Complex64); err != nil { + return nil, err + } + + data := t.Complex64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]complex64, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]complex64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3C64 converts a *Dense into a [][][]complex64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3C64(t *Dense) (retVal [][][]complex64, err error) { + if err = checkNativeIterable(t, 3, Complex64); err != nil { + return nil, err + } + + data := t.Complex64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]complex64, layers) + for i := range retVal { + retVal[i] = make([][]complex64, rows) + for j := range retVal[i] { + retVal[i][j] = make([]complex64, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for complex128 */ + +// nativeDenseVectorC128 converts a *Dense into a []complex128 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorC128(t *Dense) (retVal []complex128, err error) { + if err = checkNativeIterable(t, 1, Complex128); err != nil { + return nil, err + } + return t.Complex128s(), nil +} + +// nativeDenseMatrixC128 converts a *Dense into a [][]complex128 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixC128(t *Dense) (retVal [][]complex128, err error) { + if err = checkNativeIterable(t, 2, Complex128); err != nil { + return nil, err + } + + data := t.Complex128s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]complex128, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]complex128, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3C128 converts a *Dense into a [][][]complex128. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func nativeDenseTensor3C128(t *Dense) (retVal [][][]complex128, err error) { + if err = checkNativeIterable(t, 3, Complex128); err != nil { + return nil, err + } + + data := t.Complex128s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]complex128, layers) + for i := range retVal { + retVal[i] = make([][]complex128, rows) + for j := range retVal[i] { + retVal[i][j] = make([]complex128, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for string */ + +// nativeDenseVectorStr converts a *Dense into a []string +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorStr(t *Dense) (retVal []string, err error) { + if err = checkNativeIterable(t, 1, String); err != nil { + return nil, err + } + return t.Strings(), nil +} + +// nativeDenseMatrixStr converts a *Dense into a [][]string +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixStr(t *Dense) (retVal [][]string, err error) { + if err = checkNativeIterable(t, 2, String); err != nil { + return nil, err + } + + data := t.Strings() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]string, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]string, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3Str converts a *Dense into a [][][]string. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3Str(t *Dense) (retVal [][][]string, err error) { + if err = checkNativeIterable(t, 3, String); err != nil { + return nil, err + } + + data := t.Strings() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]string, layers) + for i := range retVal { + retVal[i] = make([][]string, rows) + for j := range retVal[i] { + retVal[i][j] = make([]string, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} diff --git a/iterator_native_test.go b/iterator_native_test.go new file mode 100644 index 0000000..afcd14d --- /dev/null +++ b/iterator_native_test.go @@ -0,0 +1,633 @@ +package tensor + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// Code generated by genlib2. DO NOT EDIT. 
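The native views exercised by these tests alias the tensor's backing memory, so writes through the returned Go slices mutate the original *Dense. A small illustrative sketch of that property (not one of the generated tests; the float64 matrix case is assumed):

func Test_nativeAliasingSketch(t *testing.T) {
	assert := assert.New(t)
	T := New(WithBacking(Range(Float64, 0, 6)), WithShape(2, 3))
	m, err := nativeDenseMatrixF64(T) // m is a [][]float64 view over T's backing data
	if err != nil {
		t.Fatal(err)
	}
	m[1][2] = 100                                // write through the native view...
	assert.Equal(100.0, T.Data().([]float64)[5]) // ...and it shows up in the tensor's data
}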
+ +func Test_nativeDenseVectorB(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(Bool), WithShape(6)) + it, err := nativeDenseVectorB(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixB(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(Bool), WithShape(2, 3)) + it, err := nativeDenseMatrixB(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3B(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(Bool), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3B(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorI(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorI(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixI(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixI(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3I(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3I(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorI8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int8, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorI8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixI8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int8, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixI8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3I8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int8, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3I8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorI16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int16, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorI16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixI16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int16, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixI16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3I16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int16, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3I16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorI32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int32, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorI32(T) + if err 
!= nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixI32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int32, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixI32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3I32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int32, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3I32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorI64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int64, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorI64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixI64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int64, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixI64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3I64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int64, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3I64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorU(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorU(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixU(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixU(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3U(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3U(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorU8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint8, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorU8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixU8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint8, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixU8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3U8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint8, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3U8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorU16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint16, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorU16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixU16(t 
*testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint16, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixU16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3U16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint16, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3U16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorU32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint32, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorU32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixU32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint32, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixU32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3U32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint32, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3U32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorU64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint64, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorU64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixU64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint64, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixU64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3U64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint64, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3U64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorF32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float32, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorF32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixF32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float32, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixF32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3F32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float32, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3F32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorF64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float64, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorF64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixF64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = 
New(WithBacking(Range(Float64, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixF64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3F64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float64, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3F64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorC64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex64, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorC64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixC64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex64, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixC64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3C64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex64, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3C64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorC128(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex128, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorC128(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixC128(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex128, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixC128(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3C128(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex128, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3C128(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorStr(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(String), WithShape(6)) + it, err := nativeDenseVectorStr(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixStr(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(String), WithShape(2, 3)) + it, err := nativeDenseMatrixStr(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3Str(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(String), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3Str(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} diff --git a/iterator_test.go b/iterator_test.go index 1d7f170..d0ca6de 100644 --- a/iterator_test.go +++ b/iterator_test.go @@ -6,6 +6,12 @@ import ( "github.com/stretchr/testify/assert" ) +// newAP is a helper function now +func newAP(shape Shape, strides []int) *AP { + ap := MakeAP(shape, strides, 0, 0) + return &ap +} + var flatIterTests1 = []struct { shape Shape strides []int @@ -14,8 +20,8 @@ var flatIterTests1 = []struct { }{ {ScalarShape(), []int{}, 
[]int{0}}, // scalar {Shape{5}, []int{1}, []int{0, 1, 2, 3, 4}}, // vector - {Shape{5, 1}, []int{1}, []int{0, 1, 2, 3, 4}}, // colvec - {Shape{1, 5}, []int{1}, []int{0, 1, 2, 3, 4}}, // rowvec + {Shape{5, 1}, []int{1, 1}, []int{0, 1, 2, 3, 4}}, // colvec + {Shape{1, 5}, []int{5, 1}, []int{0, 1, 2, 3, 4}}, // rowvec {Shape{2, 3}, []int{3, 1}, []int{0, 1, 2, 3, 4, 5}}, // basic mat {Shape{3, 2}, []int{1, 3}, []int{0, 3, 1, 4, 2, 5}}, // basic mat, transposed {Shape{2}, []int{2}, []int{0, 2}}, // basic 2x2 mat, sliced: Mat[:, 1] @@ -27,6 +33,11 @@ var flatIterTests1 = []struct { {Shape{4, 2, 3}, []int{1, 12, 4}, []int{0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23}}, // basic 3-Tensor (under (2, 0, 1) transpose) {Shape{3, 2, 4}, []int{4, 12, 1}, []int{0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 16, 17, 18, 19, 8, 9, 10, 11, 20, 21, 22, 23}}, // basic 3-Tensor (under (1, 0, 2) transpose) {Shape{4, 3, 2}, []int{1, 4, 12}, []int{0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21, 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23}}, // basic 3-Tensor (under (2, 1, 0) transpose) + + // ARTIFICIAL CASES - TODO + // These cases should be impossible to reach in normal operation + // You would have to specially construct these + // {Shape{1, 5}, []int{1}, []int{0, 1, 2, 3, 4}}, // rowvec - NEARLY IMPOSSIBLE CASE- TODO } var flatIterSlices = []struct { @@ -49,8 +60,8 @@ func TestFlatIterator(t *testing.T) { for i, fit := range flatIterTests1 { nexts = nexts[:0] err = nil - ap = NewAP(fit.shape, fit.strides) - it = NewFlatIterator(ap) + ap = newAP(fit.shape, fit.strides) + it = newFlatIterator(ap) for next, err := it.Next(); err == nil; next, err = it.Next() { nexts = append(nexts, next) } @@ -73,8 +84,8 @@ func TestFlatIteratorReverse(t *testing.T) { for i, fit := range flatIterTests1 { nexts = nexts[:0] err = nil - ap = NewAP(fit.shape, fit.strides) - it = NewFlatIterator(ap) + ap = newAP(fit.shape, fit.strides) + it = newFlatIterator(ap) it.SetReverse() for next, err := it.Next(); err == nil; next, err = it.Next() { nexts = append(nexts, next) @@ -108,7 +119,7 @@ func TestMultIterator(t *testing.T) { for i, fit := range flatIterTests1 { nexts[0] = nexts[0][:0] err = nil - ap[0] = NewAP(fit.shape, fit.strides) + ap[0] = newAP(fit.shape, fit.strides) it = NewMultIterator(ap[0]) if reverse { it.SetReverse() @@ -124,43 +135,45 @@ func TestMultIterator(t *testing.T) { nexts[0][i], nexts[0][j] = nexts[0][j], nexts[0][i] } } - assert.Equal(fit.correct, nexts[0], "Repeating flat test %d", i) + assert.Equal(fit.correct, nexts[0], "Repeating flat test %d. Reverse? %v", i, reverse) } // Test multiple iterators simultaneously - var choices = []int{0, 0, 9, 9, 0, 9} - for j := 0; j < 6; j++ { - fit := flatIterTests1[choices[j]] - nexts[j] = nexts[j][:0] - err = nil - ap[j] = NewAP(fit.shape, fit.strides) - } - it = NewMultIterator(ap...) - if reverse { - it.SetReverse() - } - for _, err := it.Next(); err == nil; _, err = it.Next() { + /* + var choices = []int{0, 0, 9, 9, 0, 9} for j := 0; j < 6; j++ { - nexts[j] = append(nexts[j], it.LastIndex(j)) + fit := flatIterTests1[choices[j]] + nexts[j] = nexts[j][:0] + err = nil + ap[j] = newAP(fit.shape, fit.strides) } - - if _, ok := err.(NoOpError); err != nil && !ok { - t.Error(err) + it = NewMultIterator(ap...) 
+ if reverse { + it.SetReverse() } - } + for _, err := it.Next(); err == nil; _, err = it.Next() { + for j := 0; j < 6; j++ { + nexts[j] = append(nexts[j], it.LastIndex(j)) + } - for j := 0; j < 6; j++ { - fit := flatIterTests1[choices[j]] - if reverse { - for i, k := 0, len(nexts[j])-1; i < k; i, k = i+1, k-1 { - nexts[j][i], nexts[j][k] = nexts[j][k], nexts[j][i] + if _, ok := err.(NoOpError); err != nil && !ok { + t.Error(err) } } - if ap[j].IsScalar() { - assert.Equal(fit.correct, nexts[j][:1], "Test multiple iterators %d", j) - } else { - assert.Equal(fit.correct, nexts[j], "Test multiple iterators %d", j) + + for j := 0; j < 6; j++ { + fit := flatIterTests1[choices[j]] + if reverse { + for i, k := 0, len(nexts[j])-1; i < k; i, k = i+1, k-1 { + nexts[j][i], nexts[j][k] = nexts[j][k], nexts[j][i] + } + } + if ap[j].IsScalar() { + assert.Equal(fit.correct, nexts[j][:1], "Test multiple iterators %d", j) + } else { + assert.Equal(fit.correct, nexts[j], "Test multiple iterators %d", j) + } } - } + */ } } @@ -177,7 +190,7 @@ func TestIteratorInterface(t *testing.T) { for i, fit := range flatIterTests1 { nexts = nexts[:0] err = nil - ap = NewAP(fit.shape, fit.strides) + ap = newAP(fit.shape, fit.strides) it = NewIterator(ap) for next, err := it.Start(); err == nil; next, err = it.Next() { nexts = append(nexts, next) @@ -223,8 +236,8 @@ func TestFlatIterator_Chan(t *testing.T) { // basic stuff for i, fit := range flatIterTests1 { nexts = nexts[:0] - ap = NewAP(fit.shape, fit.strides) - it = NewFlatIterator(ap) + ap = newAP(fit.shape, fit.strides) + it = newFlatIterator(ap) ch := it.Chan() for next := range ch { nexts = append(nexts, next) @@ -242,8 +255,8 @@ func TestFlatIterator_Slice(t *testing.T) { var nexts []int for i, fit := range flatIterTests1 { - ap = NewAP(fit.shape, fit.strides) - it = NewFlatIterator(ap) + ap = newAP(fit.shape, fit.strides) + it = newFlatIterator(ap) nexts, err = it.Slice(nil) if _, ok := err.(NoOpError); err != nil && !ok { t.Error(err) @@ -276,8 +289,8 @@ func TestFlatIterator_Coord(t *testing.T) { // var nexts []int var donecount int - ap = NewAP(Shape{2, 3, 4}, []int{12, 4, 1}) - it = NewFlatIterator(ap) + ap = newAP(Shape{2, 3, 4}, []int{12, 4, 1}) + it = newFlatIterator(ap) var correct = [][]int{ {0, 0, 1}, @@ -315,8 +328,8 @@ func TestFlatIterator_Coord(t *testing.T) { // really this is just for completeness sake func TestFlatIterator_Reset(t *testing.T) { assert := assert.New(t) - ap := NewAP(Shape{2, 3, 4}, []int{12, 4, 1}) - it := NewFlatIterator(ap) + ap := newAP(Shape{2, 3, 4}, []int{12, 4, 1}) + it := newFlatIterator(ap) it.Next() it.Next() @@ -349,7 +362,7 @@ type oldFlatIterator struct { done bool } -// NewFlatIterator creates a new FlatIterator +// newFlatIterator creates a new FlatIterator func newOldFlatIterator(ap *AP) *oldFlatIterator { return &oldFlatIterator{ AP: ap, @@ -406,7 +419,7 @@ func BenchmarkOldFlatIterator(b *testing.B) { // as if T = NewTensor(WithShape(30, 1000, 1000)) // then T[:, 0:900:15, 250:750:50] - ap := NewAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) + ap := newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) it := newOldFlatIterator(ap) for n := 0; n < b.N; n++ { @@ -426,8 +439,8 @@ func BenchmarkFlatIterator(b *testing.B) { // as if T = NewTensor(WithShape(30, 1000, 1000)) // then T[:, 0:900:15, 250:750:50] - ap := NewAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) - it := NewFlatIterator(ap) + ap := newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) + it := newFlatIterator(ap) for n := 0; n < b.N; n++ { for _, err 
:= it.Next(); err == nil; _, err = it.Next() { @@ -450,8 +463,8 @@ func BenchmarkFlatIteratorParallel6(b *testing.B) { it := make([]*FlatIterator, 6) for j := 0; j < 6; j++ { - ap[j] = NewAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) - it[j] = NewFlatIterator(ap[j]) + ap[j] = newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) + it[j] = newFlatIterator(ap[j]) } for n := 0; n < b.N; n++ { @@ -476,7 +489,7 @@ func BenchmarkFlatIteratorMulti1(b *testing.B) { // as if T = NewTensor(WithShape(30, 1000, 1000)) // then T[:, 0:900:15, 250:750:50] - ap := NewAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) + ap := newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) it := NewMultIterator(ap) @@ -496,7 +509,7 @@ func BenchmarkFlatIteratorGeneric1(b *testing.B) { // as if T = NewTensor(WithShape(30, 1000, 1000)) // then T[:, 0:900:15, 250:750:50] - ap := NewAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) + ap := newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) it := NewIterator(ap) @@ -519,7 +532,7 @@ func BenchmarkFlatIteratorMulti6(b *testing.B) { ap := make([]*AP, 6) for j := 0; j < 6; j++ { - ap[j] = NewAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) + ap[j] = newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) } it := NewMultIterator(ap...) diff --git a/junkyard_test.go b/junkyard_test.go index 428178a..6d4b43a 100644 --- a/junkyard_test.go +++ b/junkyard_test.go @@ -9,7 +9,7 @@ import ( func TestRandom(t *testing.T) { const size = 50 - for _, typ := range numberTypes.set { + for _, typ := range numberTypes { r := Random(typ, size) typR := reflect.TypeOf(r).Elem() diff --git a/known_issues_test.go b/known_issues_test.go new file mode 100644 index 0000000..3175ce7 --- /dev/null +++ b/known_issues_test.go @@ -0,0 +1,152 @@ +package tensor + +import ( + "testing" + "testing/quick" + + "github.com/stretchr/testify/assert" + "gorgonia.org/dtype" +) + +func TestIssue70(t *testing.T) { + a := 2.0 + b := NewDense(Float64, Shape{1, 1}, WithBacking([]float64{3})) + var correct interface{} = []float64{6.0} + + res, err := Mul(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + t.Logf("a %v b %v, res %v", a, b, res) +} + +func TestIssue72(t *testing.T) { + a := New(FromScalar(3.14)) + b := 0.0 + + bsa, err := Sub(b, a) + if err != nil { + t.Fatal(err) + } + t.Logf("%v", bsa) + ret, err := Sub(b, bsa, UseUnsafe()) + if err != nil { + t.Fatal(err) + } + t.Logf("%v %v", ret, bsa) + + invReuseScalar := func(q *Dense) bool { + a := q.Clone().(*Dense) + //if !a.Shape().IsScalarEquiv() { + // return true + //} + b := identityVal(0, q.t) + reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) + _, ok := q.Engine().(Suber) + we = we || !ok + //log.Printf("b-a(r) | b:%v, a %v, r %v", b, a.Shape(), reuse.Shape()) + + ret, err := Sub(b, a, WithReuse(reuse)) + if err, retEarly := qcErrCheck(t, "SubSV", a, b, we, err); retEarly { + if err != nil { + t.Logf("err %v", err) + return false + } + return true + } + //log.Printf("b-a(r) | b:%v, a %v, r %v, ret %v", b, a.Shape(), reuse.Shape(), ret.Shape()) + ret, err = Sub(b, ret, UseUnsafe()) + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + t.Errorf("a %v ", a.Shape()) + return false + } + if reuse != ret { + t.Errorf("Expected reuse to be the same as retVal") + return false + } + + return true + } + if err := quick.Check(invReuseScalar, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { 
+ t.Errorf("Inv test for Sub (scalar as left, tensor as right) failed: %v", err) + } + +} + +func TestIssue83(t *testing.T) { + backing := []float64{-1, 0, 1} + var TT Tensor + TT = New( + WithShape(1, 3), + WithBacking(backing)) + TT, _ = T(TT) + + it := IteratorFromDense(TT.(*Dense)) + for i, ierr := it.Next(); ierr == nil; i, ierr = it.Next() { + if ierr != nil { + t.Error(ierr) + } + if i >= len(backing) { + t.Errorf("Iterator should not return an `i` greater than %v", i) + } + } + + backing = []float64{1, 2, 3, 4, 5, 5, 4, 3, 2, 1} + TT = New(WithShape(10, 1, 1, 1), WithBacking(backing)) + it = IteratorFromDense(TT.(*Dense)) + + var vals []float64 + for i, ierr := it.Next(); ierr == nil; i, ierr = it.Next() { + if ierr != nil { + t.Error(ierr) + } + v := TT.Data().([]float64)[i] + vals = append(vals, v) + } + t.Logf("%v", vals) + +} + +func TestIssue88(t *testing.T) { + a := New(WithShape(4, 2), WithBacking([]float64{1, 1, 1, 1, 1, 1, 1, 1})) + b := New(WithShape(2, 4), WithBacking([]float64{0, 1, 0, 1, 0, 1, 0, 1})) + c, _ := a.MatMul(b) + _, err := Div(c, 2) + if err == nil { + t.Fatal("Expected an error") + } +} + +var ltoiTestCases = []struct { + name string + shape Shape + strides []int + coordinates []int + correct int + willErr bool +}{ + {"\"scalar\" - scalarshape", Shape{}, nil, []int{0}, 0, false}, + {"\"scalar\" - scalarshape, non empty strides", Shape{}, []int{1}, []int{0}, 0, false}, + {"\"scalar\" - scalarlike", Shape{1, 1, 1}, []int{1, 1, 1}, []int{0, 0, 0}, 0, false}, + {"vector", Shape{10}, []int{1}, []int{1}, 1, false}, + {"rowvec", Shape{1, 10}, []int{10, 1}, []int{0, 1}, 1, false}, + {"colvec", Shape{10, 1}, []int{1, 1}, []int{1, 0}, 1, false}, + {"rowvec- funny strides", Shape{1, 10}, []int{1}, []int{0, 1}, 1, false}, + {"colvec - funny strides", Shape{10, 1}, []int{1}, []int{1, 0}, 1, false}, +} + +func TestIssue90(t *testing.T) { + for i, c := range ltoiTestCases { + at, err := Ltoi(c.shape, c.strides, c.coordinates...) + if !checkErr(t, c.willErr, err, c.name, i) { + continue + } + if at != c.correct { + t.Errorf("Expected Ltoi(%v, %v, %v) to be %v. Got %v instead", c.shape, c.strides, c.coordinates, c.correct, at) + } + } +} diff --git a/known_race_test.go b/known_race_test.go new file mode 100644 index 0000000..f6d5616 --- /dev/null +++ b/known_race_test.go @@ -0,0 +1,34 @@ +// +build ignore +// +build !race + +package tensor + +import ( + "testing" + "unsafe" + + "github.com/stretchr/testify/assert" +) + +// This test will fail the `go test -race`. +// +// This is because FromMemory() will use uintptr in a way that is incorrect according to the checkptr directive of Go 1.14+ +// +// Though it's incorrect, it's the only way to use heterogenous, readable memory (i.e. CUDA). +func TestFromMemory(t *testing.T) { + // dummy memory - this could be an externally malloc'd memory, or a mmap'ed file. + // but here we're just gonna let Go manage memory. 
+ s := make([]float64, 100) + ptr := uintptr(unsafe.Pointer(&s[0])) + size := uintptr(100 * 8) + + T := New(Of(Float32), WithShape(50, 4), FromMemory(ptr, size)) + if len(T.Float32s()) != 200 { + t.Error("expected 200 Float32s") + } + assert.Equal(t, make([]float32, 200), T.Data()) + assert.True(t, T.IsManuallyManaged(), "Unamanged %v |%v | q: %v", ManuallyManaged, T.flag, (T.flag>>ManuallyManaged)&MemoryFlag(1)) + + fail := func() { New(FromMemory(ptr, size), Of(Float32)) } + assert.Panics(t, fail, "Expected bad New() call to panic") +} diff --git a/mathutils.go b/mathutils.go index 88ebbae..8060ad3 100644 --- a/mathutils.go +++ b/mathutils.go @@ -1,4 +1,4 @@ -// +build !noasm +// +build amd64,!noasm package tensor diff --git a/mathutils_go.go b/mathutils_go.go index 1a5f2c1..299bcbe 100644 --- a/mathutils_go.go +++ b/mathutils_go.go @@ -1,4 +1,4 @@ -// +build noasm +// +build !amd64 noasm package tensor diff --git a/native/doc.go b/native/doc.go new file mode 100644 index 0000000..516fbe2 --- /dev/null +++ b/native/doc.go @@ -0,0 +1,8 @@ +// package native is a utility package for gorgonia.org/tensor. +// +// Amongst other things, it provides iterators that use Go slice semantics, while keeping a reference to the underlying memory. +// This means you can update the slices and the changes will be reflected back into the original tensor. +// +// There is of course a cost of using the native iterators and selectors - allocation costs. +// For best performance, don't use these in a tight loop. +package native diff --git a/native/example_test.go b/native/example_test.go new file mode 100644 index 0000000..740d103 --- /dev/null +++ b/native/example_test.go @@ -0,0 +1,81 @@ +package native + +import ( + "fmt" + + . "gorgonia.org/tensor" +) + +// There are times where it is more effective to use native Go slice semantics to do work (for example, when performing batch work over kernels). +// Iterators are useful for this purpose. This package provides iterators for the standard types +// However, custom types are also available. See Vector, Matrix and Tensor3 examples. +func Example_iterator() { + var T *Dense + T = New(WithShape(2, 3), WithBacking(Range(Float64, 0, 6))) + x, err := MatrixF64(T) + if err != nil { + fmt.Printf("ERR: %v", err) + } + + for _, row := range x { + fmt.Printf("%v\n", row) + } + + // Output: + // [0 1 2] + // [3 4 5] +} + +// The NativeSelect function squashes the dimensions, and returns an iterator in native Go slice semantics. +func Example_select() { + // Selection is a bit of an interesting use case. Sometimes you don't want to iterate through the layers. + // + // For example, in a number of use cases where you have a 4-Tensor, you'd typically reshape it to some + // 2D matrix which can then be plugged into BLAS algorithms directly. Sometimes you wouldn't need to reshape. + // All you have to do is squash the dimensions inwards. This function does that. + // + // The best way to explain the Select functions is through concrete examples. + // Imagine a tensor with (2,3,4,5) shape. Arbitrarily, we call them (NCHW) - Batch Size, Channel Count, Height, Width. + // If we want to select all the channels, across all batches, then `NativeSelectX(T, 1)` would yield all channels. The resulting matrix will be (6, 20) + // If we want to select all the heights, across all channels and batches, then `NativeSelectX(T, 2) will yield all heights. The resulting matrix will be (24, 5) + // + // If for some reason the format was in NHWC, then you would need to reshape. 
This wouldn't be useful. + + var T *Dense + T = New(WithShape(2, 3, 4, 5), WithBacking(Range(Float64, 0, 2*3*4*5))) + x, err := SelectF64(T, 1) + if err != nil { + fmt.Printf("ERR %v", err) + } + for _, row := range x { + fmt.Printf("%3.0f\n", row) + } + + // Output: + // [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19] + // [ 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39] + // [ 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59] + // [ 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79] + // [ 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99] + // [100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119] +} + +// The iterators are iteratos in the truest sense. The data isn't copied, as this example shows +func Example_clobber() { + var T *Dense + T = New(WithShape(2, 3), WithBacking(Range(Float64, 0, 6))) + fmt.Printf("Before :\n%v", T) + + xx, _ := MatrixF64(T) + xx[1][1] = 10000 + fmt.Printf("After :\n%v", T) + + // Output: + // Before : + // ⎡0 1 2⎤ + // ⎣3 4 5⎦ + // After : + // ⎡ 0 1 2⎤ + // ⎣ 3 10000 5⎦ + +} diff --git a/native/generic.go b/native/generic.go new file mode 100644 index 0000000..79d8dc3 --- /dev/null +++ b/native/generic.go @@ -0,0 +1,72 @@ +package native + +import ( + "reflect" + "unsafe" + + . "gorgonia.org/tensor" +) + +func Vector(t *Dense) (interface{}, error) { + if err := checkNativeIterable(t, 1, t.Dtype()); err != nil { + return nil, err + } + return t.Data(), nil +} + +func Matrix(t *Dense) (interface{}, error) { + if err := checkNativeIterable(t, 2, t.Dtype()); err != nil { + return nil, err + } + + shape := t.Shape() + strides := t.Strides() + typ := t.Dtype().Type + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + + retVal := reflect.MakeSlice(reflect.SliceOf(reflect.SliceOf(typ)), rows, rows) + ptr := t.Uintptr() + for i := 0; i < rows; i++ { + e := retVal.Index(i) + sh := (*reflect.SliceHeader)(unsafe.Pointer(e.Addr().Pointer())) + sh.Data = uintptr(i*rowStride)*typ.Size() + ptr + sh.Len = cols + sh.Cap = cols + } + return retVal.Interface(), nil +} + +func Tensor3(t *Dense) (interface{}, error) { + if err := checkNativeIterable(t, 3, t.Dtype()); err != nil { + return nil, err + } + shape := t.Shape() + strides := t.Strides() + typ := t.Dtype().Type + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal := reflect.MakeSlice(reflect.SliceOf(reflect.SliceOf(reflect.SliceOf(typ))), layers, layers) + ptr := t.Uintptr() + for i := 0; i < layers; i++ { + el := retVal.Index(i) + inner := reflect.MakeSlice(reflect.SliceOf(reflect.SliceOf(typ)), rows, rows) + for j := 0; j < rows; j++ { + e := inner.Index(j) + sh := (*reflect.SliceHeader)(unsafe.Pointer(e.Addr().Pointer())) + sh.Data = uintptr(i*layerStride+j*rowStride)*typ.Size() + ptr + sh.Len = cols + sh.Cap = cols + } + sh := (*reflect.SliceHeader)(unsafe.Pointer(el.Addr().Pointer())) + sh.Data = inner.Index(0).Addr().Pointer() + sh.Len = rows + sh.Cap = rows + } + return retVal.Interface(), nil +} diff --git a/native/generic_test.go b/native/generic_test.go new file mode 100644 index 0000000..cf09802 --- /dev/null +++ b/native/generic_test.go @@ -0,0 +1,67 @@ +package native_test + +import ( + "fmt" + + "gorgonia.org/tensor" + . 
"gorgonia.org/tensor/native" +) + +type MyType int + +func Example_vector() { + backing := []MyType{ + 0, 1, 2, 3, + } + T := tensor.New(tensor.WithShape(4), tensor.WithBacking(backing)) + val, err := Vector(T) + if err != nil { + fmt.Printf("error: %v", err) + } + it := val.([]MyType) + fmt.Println(it) + + // Output: + // [0 1 2 3] +} + +func Example_matrix() { + backing := []MyType{ + 0, 1, + 2, 3, + 4, 5, + } + T := tensor.New(tensor.WithShape(3, 2), tensor.WithBacking(backing)) + val, err := Matrix(T) + if err != nil { + fmt.Printf("error: %v", err) + } + + it := val.([][]MyType) + fmt.Println(it) + + // Output: + // [[0 1] [2 3] [4 5]] +} + +func Example_tensor3() { + backing := []MyType{ + 0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + + 12, 13, 14, 15, + 16, 17, 18, 19, + 20, 21, 22, 23, + } + T := tensor.New(tensor.WithShape(2, 3, 4), tensor.WithBacking(backing)) + val, err := Tensor3(T) + if err != nil { + fmt.Printf("error: %v", err) + } + it := val.([][][]MyType) + fmt.Println(it) + + //Output: + // [[[0 1 2 3] [4 5 6 7] [8 9 10 11]] [[12 13 14 15] [16 17 18 19] [20 21 22 23]]] +} diff --git a/native/iterator_native.go b/native/iterator_native.go new file mode 100644 index 0000000..1ad0573 --- /dev/null +++ b/native/iterator_native.go @@ -0,0 +1,332 @@ +//go:build !purego +// +build !purego + +package native + +// Code generated by genlib2. DO NOT EDIT. + +import ( + _ "unsafe" + + "gorgonia.org/tensor" +) + +//go:linkname VectorB gorgonia.org/tensor.nativeDenseVectorB + +// VectorB converts a *Dense into a []bool +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorB(t *tensor.Dense) (retVal []bool, err error) + +//go:linkname MatrixB gorgonia.org/tensor.nativeDenseMatrixB + +// MatrixB converts a *Dense into a [][]bool +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixB(t *tensor.Dense) (retVal [][]bool, err error) + +//go:linkname Tensor3B gorgonia.org/tensor.nativeDenseTensor3B + +// Tensor3B converts a *Dense into a [][][]bool. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3B(t *tensor.Dense) (retVal [][][]bool, err error) + +//go:linkname VectorI gorgonia.org/tensor.nativeDenseVectorI + +// VectorI converts a *Dense into a []int +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorI(t *tensor.Dense) (retVal []int, err error) + +//go:linkname MatrixI gorgonia.org/tensor.nativeDenseMatrixI + +// MatrixI converts a *Dense into a [][]int +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixI(t *tensor.Dense) (retVal [][]int, err error) + +//go:linkname Tensor3I gorgonia.org/tensor.nativeDenseTensor3I + +// Tensor3I converts a *Dense into a [][][]int. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3I(t *tensor.Dense) (retVal [][][]int, err error) + +//go:linkname VectorI8 gorgonia.org/tensor.nativeDenseVectorI8 + +// VectorI8 converts a *Dense into a []int8 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorI8(t *tensor.Dense) (retVal []int8, err error) + +//go:linkname MatrixI8 gorgonia.org/tensor.nativeDenseMatrixI8 + +// MatrixI8 converts a *Dense into a [][]int8 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. 
+func MatrixI8(t *tensor.Dense) (retVal [][]int8, err error) + +//go:linkname Tensor3I8 gorgonia.org/tensor.nativeDenseTensor3I8 + +// Tensor3I8 converts a *Dense into a [][][]int8. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3I8(t *tensor.Dense) (retVal [][][]int8, err error) + +//go:linkname VectorI16 gorgonia.org/tensor.nativeDenseVectorI16 + +// VectorI16 converts a *Dense into a []int16 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorI16(t *tensor.Dense) (retVal []int16, err error) + +//go:linkname MatrixI16 gorgonia.org/tensor.nativeDenseMatrixI16 + +// MatrixI16 converts a *Dense into a [][]int16 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixI16(t *tensor.Dense) (retVal [][]int16, err error) + +//go:linkname Tensor3I16 gorgonia.org/tensor.nativeDenseTensor3I16 + +// Tensor3I16 converts a *Dense into a [][][]int16. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3I16(t *tensor.Dense) (retVal [][][]int16, err error) + +//go:linkname VectorI32 gorgonia.org/tensor.nativeDenseVectorI32 + +// VectorI32 converts a *Dense into a []int32 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorI32(t *tensor.Dense) (retVal []int32, err error) + +//go:linkname MatrixI32 gorgonia.org/tensor.nativeDenseMatrixI32 + +// MatrixI32 converts a *Dense into a [][]int32 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixI32(t *tensor.Dense) (retVal [][]int32, err error) + +//go:linkname Tensor3I32 gorgonia.org/tensor.nativeDenseTensor3I32 + +// Tensor3I32 converts a *Dense into a [][][]int32. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3I32(t *tensor.Dense) (retVal [][][]int32, err error) + +//go:linkname VectorI64 gorgonia.org/tensor.nativeDenseVectorI64 + +// VectorI64 converts a *Dense into a []int64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorI64(t *tensor.Dense) (retVal []int64, err error) + +//go:linkname MatrixI64 gorgonia.org/tensor.nativeDenseMatrixI64 + +// MatrixI64 converts a *Dense into a [][]int64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixI64(t *tensor.Dense) (retVal [][]int64, err error) + +//go:linkname Tensor3I64 gorgonia.org/tensor.nativeDenseTensor3I64 + +// Tensor3I64 converts a *Dense into a [][][]int64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3I64(t *tensor.Dense) (retVal [][][]int64, err error) + +//go:linkname VectorU gorgonia.org/tensor.nativeDenseVectorU + +// VectorU converts a *Dense into a []uint +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorU(t *tensor.Dense) (retVal []uint, err error) + +//go:linkname MatrixU gorgonia.org/tensor.nativeDenseMatrixU + +// MatrixU converts a *Dense into a [][]uint +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixU(t *tensor.Dense) (retVal [][]uint, err error) + +//go:linkname Tensor3U gorgonia.org/tensor.nativeDenseTensor3U + +// Tensor3U converts a *Dense into a [][][]uint. 
+// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3U(t *tensor.Dense) (retVal [][][]uint, err error) + +//go:linkname VectorU8 gorgonia.org/tensor.nativeDenseVectorU8 + +// VectorU8 converts a *Dense into a []uint8 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorU8(t *tensor.Dense) (retVal []uint8, err error) + +//go:linkname MatrixU8 gorgonia.org/tensor.nativeDenseMatrixU8 + +// MatrixU8 converts a *Dense into a [][]uint8 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixU8(t *tensor.Dense) (retVal [][]uint8, err error) + +//go:linkname Tensor3U8 gorgonia.org/tensor.nativeDenseTensor3U8 + +// Tensor3U8 converts a *Dense into a [][][]uint8. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3U8(t *tensor.Dense) (retVal [][][]uint8, err error) + +//go:linkname VectorU16 gorgonia.org/tensor.nativeDenseVectorU16 + +// VectorU16 converts a *Dense into a []uint16 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorU16(t *tensor.Dense) (retVal []uint16, err error) + +//go:linkname MatrixU16 gorgonia.org/tensor.nativeDenseMatrixU16 + +// MatrixU16 converts a *Dense into a [][]uint16 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixU16(t *tensor.Dense) (retVal [][]uint16, err error) + +//go:linkname Tensor3U16 gorgonia.org/tensor.nativeDenseTensor3U16 + +// Tensor3U16 converts a *Dense into a [][][]uint16. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3U16(t *tensor.Dense) (retVal [][][]uint16, err error) + +//go:linkname VectorU32 gorgonia.org/tensor.nativeDenseVectorU32 + +// VectorU32 converts a *Dense into a []uint32 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorU32(t *tensor.Dense) (retVal []uint32, err error) + +//go:linkname MatrixU32 gorgonia.org/tensor.nativeDenseMatrixU32 + +// MatrixU32 converts a *Dense into a [][]uint32 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixU32(t *tensor.Dense) (retVal [][]uint32, err error) + +//go:linkname Tensor3U32 gorgonia.org/tensor.nativeDenseTensor3U32 + +// Tensor3U32 converts a *Dense into a [][][]uint32. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3U32(t *tensor.Dense) (retVal [][][]uint32, err error) + +//go:linkname VectorU64 gorgonia.org/tensor.nativeDenseVectorU64 + +// VectorU64 converts a *Dense into a []uint64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorU64(t *tensor.Dense) (retVal []uint64, err error) + +//go:linkname MatrixU64 gorgonia.org/tensor.nativeDenseMatrixU64 + +// MatrixU64 converts a *Dense into a [][]uint64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixU64(t *tensor.Dense) (retVal [][]uint64, err error) + +//go:linkname Tensor3U64 gorgonia.org/tensor.nativeDenseTensor3U64 + +// Tensor3U64 converts a *Dense into a [][][]uint64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func Tensor3U64(t *tensor.Dense) (retVal [][][]uint64, err error) + +//go:linkname VectorF32 gorgonia.org/tensor.nativeDenseVectorF32 + +// VectorF32 converts a *Dense into a []float32 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorF32(t *tensor.Dense) (retVal []float32, err error) + +//go:linkname MatrixF32 gorgonia.org/tensor.nativeDenseMatrixF32 + +// MatrixF32 converts a *Dense into a [][]float32 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixF32(t *tensor.Dense) (retVal [][]float32, err error) + +//go:linkname Tensor3F32 gorgonia.org/tensor.nativeDenseTensor3F32 + +// Tensor3F32 converts a *Dense into a [][][]float32. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3F32(t *tensor.Dense) (retVal [][][]float32, err error) + +//go:linkname VectorF64 gorgonia.org/tensor.nativeDenseVectorF64 + +// VectorF64 converts a *Dense into a []float64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorF64(t *tensor.Dense) (retVal []float64, err error) + +//go:linkname MatrixF64 gorgonia.org/tensor.nativeDenseMatrixF64 + +// MatrixF64 converts a *Dense into a [][]float64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixF64(t *tensor.Dense) (retVal [][]float64, err error) + +//go:linkname Tensor3F64 gorgonia.org/tensor.nativeDenseTensor3F64 + +// Tensor3F64 converts a *Dense into a [][][]float64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3F64(t *tensor.Dense) (retVal [][][]float64, err error) + +//go:linkname VectorC64 gorgonia.org/tensor.nativeDenseVectorC64 + +// VectorC64 converts a *Dense into a []complex64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorC64(t *tensor.Dense) (retVal []complex64, err error) + +//go:linkname MatrixC64 gorgonia.org/tensor.nativeDenseMatrixC64 + +// MatrixC64 converts a *Dense into a [][]complex64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixC64(t *tensor.Dense) (retVal [][]complex64, err error) + +//go:linkname Tensor3C64 gorgonia.org/tensor.nativeDenseTensor3C64 + +// Tensor3C64 converts a *Dense into a [][][]complex64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3C64(t *tensor.Dense) (retVal [][][]complex64, err error) + +//go:linkname VectorC128 gorgonia.org/tensor.nativeDenseVectorC128 + +// VectorC128 converts a *Dense into a []complex128 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorC128(t *tensor.Dense) (retVal []complex128, err error) + +//go:linkname MatrixC128 gorgonia.org/tensor.nativeDenseMatrixC128 + +// MatrixC128 converts a *Dense into a [][]complex128 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixC128(t *tensor.Dense) (retVal [][]complex128, err error) + +//go:linkname Tensor3C128 gorgonia.org/tensor.nativeDenseTensor3C128 + +// Tensor3C128 converts a *Dense into a [][][]complex128. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func Tensor3C128(t *tensor.Dense) (retVal [][][]complex128, err error) + +//go:linkname VectorStr gorgonia.org/tensor.nativeDenseVectorStr + +// VectorStr converts a *Dense into a []string +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorStr(t *tensor.Dense) (retVal []string, err error) + +//go:linkname MatrixStr gorgonia.org/tensor.nativeDenseMatrixStr + +// MatrixStr converts a *Dense into a [][]string +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixStr(t *tensor.Dense) (retVal [][]string, err error) + +//go:linkname Tensor3Str gorgonia.org/tensor.nativeDenseTensor3Str + +// Tensor3Str converts a *Dense into a [][][]string. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3Str(t *tensor.Dense) (retVal [][][]string, err error) diff --git a/native/iterator_native_purego.go b/native/iterator_native_purego.go new file mode 100644 index 0000000..aba1b50 --- /dev/null +++ b/native/iterator_native_purego.go @@ -0,0 +1,1133 @@ +//go:build purego +// +build purego + +package native + +// Code generated by genlib2. DO NOT EDIT. + +import ( + "reflect" + "unsafe" + + . "gorgonia.org/tensor" +) + +/* Native Iterables for bool */ + +// VectorB converts a *Dense into a []bool +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorB(t *Dense) (retVal []bool, err error) { + if err = checkNativeIterable(t, 1, Bool); err != nil { + return nil, err + } + return t.Bools(), nil +} + +// MatrixB converts a *Dense into a [][]bool +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixB(t *Dense) (retVal [][]bool, err error) { + if err = checkNativeIterable(t, 2, Bool); err != nil { + return nil, err + } + + data := t.Bools() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]bool, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]bool, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3B converts a *Dense into a [][][]bool. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3B(t *Dense) (retVal [][][]bool, err error) { + if err = checkNativeIterable(t, 3, Bool); err != nil { + return nil, err + } + + data := t.Bools() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]bool, layers) + for i := range retVal { + retVal[i] = make([][]bool, rows) + for j := range retVal[i] { + retVal[i][j] = make([]bool, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int */ + +// VectorI converts a *Dense into a []int +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. 
+func VectorI(t *Dense) (retVal []int, err error) { + if err = checkNativeIterable(t, 1, Int); err != nil { + return nil, err + } + return t.Ints(), nil +} + +// MatrixI converts a *Dense into a [][]int +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixI(t *Dense) (retVal [][]int, err error) { + if err = checkNativeIterable(t, 2, Int); err != nil { + return nil, err + } + + data := t.Ints() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3I converts a *Dense into a [][][]int. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3I(t *Dense) (retVal [][][]int, err error) { + if err = checkNativeIterable(t, 3, Int); err != nil { + return nil, err + } + + data := t.Ints() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int, layers) + for i := range retVal { + retVal[i] = make([][]int, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int8 */ + +// VectorI8 converts a *Dense into a []int8 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorI8(t *Dense) (retVal []int8, err error) { + if err = checkNativeIterable(t, 1, Int8); err != nil { + return nil, err + } + return t.Int8s(), nil +} + +// MatrixI8 converts a *Dense into a [][]int8 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixI8(t *Dense) (retVal [][]int8, err error) { + if err = checkNativeIterable(t, 2, Int8); err != nil { + return nil, err + } + + data := t.Int8s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int8, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3I8 converts a *Dense into a [][][]int8. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func Tensor3I8(t *Dense) (retVal [][][]int8, err error) { + if err = checkNativeIterable(t, 3, Int8); err != nil { + return nil, err + } + + data := t.Int8s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int8, layers) + for i := range retVal { + retVal[i] = make([][]int8, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int8, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int16 */ + +// VectorI16 converts a *Dense into a []int16 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorI16(t *Dense) (retVal []int16, err error) { + if err = checkNativeIterable(t, 1, Int16); err != nil { + return nil, err + } + return t.Int16s(), nil +} + +// MatrixI16 converts a *Dense into a [][]int16 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixI16(t *Dense) (retVal [][]int16, err error) { + if err = checkNativeIterable(t, 2, Int16); err != nil { + return nil, err + } + + data := t.Int16s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int16, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3I16 converts a *Dense into a [][][]int16. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3I16(t *Dense) (retVal [][][]int16, err error) { + if err = checkNativeIterable(t, 3, Int16); err != nil { + return nil, err + } + + data := t.Int16s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int16, layers) + for i := range retVal { + retVal[i] = make([][]int16, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int16, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int32 */ + +// VectorI32 converts a *Dense into a []int32 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorI32(t *Dense) (retVal []int32, err error) { + if err = checkNativeIterable(t, 1, Int32); err != nil { + return nil, err + } + return t.Int32s(), nil +} + +// MatrixI32 converts a *Dense into a [][]int32 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. 
+func MatrixI32(t *Dense) (retVal [][]int32, err error) { + if err = checkNativeIterable(t, 2, Int32); err != nil { + return nil, err + } + + data := t.Int32s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int32, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3I32 converts a *Dense into a [][][]int32. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3I32(t *Dense) (retVal [][][]int32, err error) { + if err = checkNativeIterable(t, 3, Int32); err != nil { + return nil, err + } + + data := t.Int32s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int32, layers) + for i := range retVal { + retVal[i] = make([][]int32, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int32, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int64 */ + +// VectorI64 converts a *Dense into a []int64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorI64(t *Dense) (retVal []int64, err error) { + if err = checkNativeIterable(t, 1, Int64); err != nil { + return nil, err + } + return t.Int64s(), nil +} + +// MatrixI64 converts a *Dense into a [][]int64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixI64(t *Dense) (retVal [][]int64, err error) { + if err = checkNativeIterable(t, 2, Int64); err != nil { + return nil, err + } + + data := t.Int64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int64, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3I64 converts a *Dense into a [][][]int64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3I64(t *Dense) (retVal [][][]int64, err error) { + if err = checkNativeIterable(t, 3, Int64); err != nil { + return nil, err + } + + data := t.Int64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int64, layers) + for i := range retVal { + retVal[i] = make([][]int64, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int64, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint */ + +// VectorU converts a *Dense into a []uint +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. 
+func VectorU(t *Dense) (retVal []uint, err error) { + if err = checkNativeIterable(t, 1, Uint); err != nil { + return nil, err + } + return t.Uints(), nil +} + +// MatrixU converts a *Dense into a [][]uint +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixU(t *Dense) (retVal [][]uint, err error) { + if err = checkNativeIterable(t, 2, Uint); err != nil { + return nil, err + } + + data := t.Uints() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3U converts a *Dense into a [][][]uint. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3U(t *Dense) (retVal [][][]uint, err error) { + if err = checkNativeIterable(t, 3, Uint); err != nil { + return nil, err + } + + data := t.Uints() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint, layers) + for i := range retVal { + retVal[i] = make([][]uint, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint8 */ + +// VectorU8 converts a *Dense into a []uint8 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorU8(t *Dense) (retVal []uint8, err error) { + if err = checkNativeIterable(t, 1, Uint8); err != nil { + return nil, err + } + return t.Uint8s(), nil +} + +// MatrixU8 converts a *Dense into a [][]uint8 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixU8(t *Dense) (retVal [][]uint8, err error) { + if err = checkNativeIterable(t, 2, Uint8); err != nil { + return nil, err + } + + data := t.Uint8s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint8, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3U8 converts a *Dense into a [][][]uint8. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func Tensor3U8(t *Dense) (retVal [][][]uint8, err error) { + if err = checkNativeIterable(t, 3, Uint8); err != nil { + return nil, err + } + + data := t.Uint8s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint8, layers) + for i := range retVal { + retVal[i] = make([][]uint8, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint8, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint16 */ + +// VectorU16 converts a *Dense into a []uint16 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorU16(t *Dense) (retVal []uint16, err error) { + if err = checkNativeIterable(t, 1, Uint16); err != nil { + return nil, err + } + return t.Uint16s(), nil +} + +// MatrixU16 converts a *Dense into a [][]uint16 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixU16(t *Dense) (retVal [][]uint16, err error) { + if err = checkNativeIterable(t, 2, Uint16); err != nil { + return nil, err + } + + data := t.Uint16s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint16, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3U16 converts a *Dense into a [][][]uint16. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3U16(t *Dense) (retVal [][][]uint16, err error) { + if err = checkNativeIterable(t, 3, Uint16); err != nil { + return nil, err + } + + data := t.Uint16s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint16, layers) + for i := range retVal { + retVal[i] = make([][]uint16, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint16, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint32 */ + +// VectorU32 converts a *Dense into a []uint32 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorU32(t *Dense) (retVal []uint32, err error) { + if err = checkNativeIterable(t, 1, Uint32); err != nil { + return nil, err + } + return t.Uint32s(), nil +} + +// MatrixU32 converts a *Dense into a [][]uint32 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. 
+func MatrixU32(t *Dense) (retVal [][]uint32, err error) { + if err = checkNativeIterable(t, 2, Uint32); err != nil { + return nil, err + } + + data := t.Uint32s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint32, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3U32 converts a *Dense into a [][][]uint32. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3U32(t *Dense) (retVal [][][]uint32, err error) { + if err = checkNativeIterable(t, 3, Uint32); err != nil { + return nil, err + } + + data := t.Uint32s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint32, layers) + for i := range retVal { + retVal[i] = make([][]uint32, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint32, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint64 */ + +// VectorU64 converts a *Dense into a []uint64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorU64(t *Dense) (retVal []uint64, err error) { + if err = checkNativeIterable(t, 1, Uint64); err != nil { + return nil, err + } + return t.Uint64s(), nil +} + +// MatrixU64 converts a *Dense into a [][]uint64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixU64(t *Dense) (retVal [][]uint64, err error) { + if err = checkNativeIterable(t, 2, Uint64); err != nil { + return nil, err + } + + data := t.Uint64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint64, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3U64 converts a *Dense into a [][][]uint64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3U64(t *Dense) (retVal [][][]uint64, err error) { + if err = checkNativeIterable(t, 3, Uint64); err != nil { + return nil, err + } + + data := t.Uint64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint64, layers) + for i := range retVal { + retVal[i] = make([][]uint64, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint64, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for float32 */ + +// VectorF32 converts a *Dense into a []float32 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. 
+func VectorF32(t *Dense) (retVal []float32, err error) { + if err = checkNativeIterable(t, 1, Float32); err != nil { + return nil, err + } + return t.Float32s(), nil +} + +// MatrixF32 converts a *Dense into a [][]float32 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixF32(t *Dense) (retVal [][]float32, err error) { + if err = checkNativeIterable(t, 2, Float32); err != nil { + return nil, err + } + + data := t.Float32s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]float32, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]float32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3F32 converts a *Dense into a [][][]float32. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3F32(t *Dense) (retVal [][][]float32, err error) { + if err = checkNativeIterable(t, 3, Float32); err != nil { + return nil, err + } + + data := t.Float32s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]float32, layers) + for i := range retVal { + retVal[i] = make([][]float32, rows) + for j := range retVal[i] { + retVal[i][j] = make([]float32, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for float64 */ + +// VectorF64 converts a *Dense into a []float64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorF64(t *Dense) (retVal []float64, err error) { + if err = checkNativeIterable(t, 1, Float64); err != nil { + return nil, err + } + return t.Float64s(), nil +} + +// MatrixF64 converts a *Dense into a [][]float64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixF64(t *Dense) (retVal [][]float64, err error) { + if err = checkNativeIterable(t, 2, Float64); err != nil { + return nil, err + } + + data := t.Float64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]float64, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]float64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3F64 converts a *Dense into a [][][]float64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func Tensor3F64(t *Dense) (retVal [][][]float64, err error) { + if err = checkNativeIterable(t, 3, Float64); err != nil { + return nil, err + } + + data := t.Float64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]float64, layers) + for i := range retVal { + retVal[i] = make([][]float64, rows) + for j := range retVal[i] { + retVal[i][j] = make([]float64, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for complex64 */ + +// VectorC64 converts a *Dense into a []complex64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorC64(t *Dense) (retVal []complex64, err error) { + if err = checkNativeIterable(t, 1, Complex64); err != nil { + return nil, err + } + return t.Complex64s(), nil +} + +// MatrixC64 converts a *Dense into a [][]complex64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixC64(t *Dense) (retVal [][]complex64, err error) { + if err = checkNativeIterable(t, 2, Complex64); err != nil { + return nil, err + } + + data := t.Complex64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]complex64, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]complex64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3C64 converts a *Dense into a [][][]complex64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3C64(t *Dense) (retVal [][][]complex64, err error) { + if err = checkNativeIterable(t, 3, Complex64); err != nil { + return nil, err + } + + data := t.Complex64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]complex64, layers) + for i := range retVal { + retVal[i] = make([][]complex64, rows) + for j := range retVal[i] { + retVal[i][j] = make([]complex64, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for complex128 */ + +// VectorC128 converts a *Dense into a []complex128 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorC128(t *Dense) (retVal []complex128, err error) { + if err = checkNativeIterable(t, 1, Complex128); err != nil { + return nil, err + } + return t.Complex128s(), nil +} + +// MatrixC128 converts a *Dense into a [][]complex128 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. 
+func MatrixC128(t *Dense) (retVal [][]complex128, err error) { + if err = checkNativeIterable(t, 2, Complex128); err != nil { + return nil, err + } + + data := t.Complex128s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]complex128, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]complex128, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3C128 converts a *Dense into a [][][]complex128. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3C128(t *Dense) (retVal [][][]complex128, err error) { + if err = checkNativeIterable(t, 3, Complex128); err != nil { + return nil, err + } + + data := t.Complex128s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]complex128, layers) + for i := range retVal { + retVal[i] = make([][]complex128, rows) + for j := range retVal[i] { + retVal[i][j] = make([]complex128, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for string */ + +// VectorStr converts a *Dense into a []string +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorStr(t *Dense) (retVal []string, err error) { + if err = checkNativeIterable(t, 1, String); err != nil { + return nil, err + } + return t.Strings(), nil +} + +// MatrixStr converts a *Dense into a [][]string +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixStr(t *Dense) (retVal [][]string, err error) { + if err = checkNativeIterable(t, 2, String); err != nil { + return nil, err + } + + data := t.Strings() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]string, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]string, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3Str converts a *Dense into a [][][]string. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func Tensor3Str(t *Dense) (retVal [][][]string, err error) { + if err = checkNativeIterable(t, 3, String); err != nil { + return nil, err + } + + data := t.Strings() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]string, layers) + for i := range retVal { + retVal[i] = make([][]string, rows) + for j := range retVal[i] { + retVal[i][j] = make([]string, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} diff --git a/native/iterator_native_test.go b/native/iterator_native_test.go new file mode 100644 index 0000000..2e99966 --- /dev/null +++ b/native/iterator_native_test.go @@ -0,0 +1,634 @@ +package native + +// Code generated by genlib2. DO NOT EDIT. + +import ( + "testing" + + "github.com/stretchr/testify/assert" + . "gorgonia.org/tensor" +) + +func Test_VectorB(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(Bool), WithShape(6)) + it, err := VectorB(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixB(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(Bool), WithShape(2, 3)) + it, err := MatrixB(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3B(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(Bool), WithShape(2, 3, 4)) + it, err := Tensor3B(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorI(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int, 0, 6)), WithShape(6)) + it, err := VectorI(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixI(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int, 0, 6)), WithShape(2, 3)) + it, err := MatrixI(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3I(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3I(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorI8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int8, 0, 6)), WithShape(6)) + it, err := VectorI8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixI8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int8, 0, 6)), WithShape(2, 3)) + it, err := MatrixI8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3I8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int8, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3I8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorI16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int16, 0, 6)), WithShape(6)) + it, err := VectorI16(T) + if err != nil { + t.Fatal(err) + } + + 
assert.Equal(6, len(it)) +} + +func Test_MatrixI16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int16, 0, 6)), WithShape(2, 3)) + it, err := MatrixI16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3I16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int16, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3I16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorI32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int32, 0, 6)), WithShape(6)) + it, err := VectorI32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixI32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int32, 0, 6)), WithShape(2, 3)) + it, err := MatrixI32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3I32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int32, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3I32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorI64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int64, 0, 6)), WithShape(6)) + it, err := VectorI64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixI64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int64, 0, 6)), WithShape(2, 3)) + it, err := MatrixI64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3I64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int64, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3I64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorU(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint, 0, 6)), WithShape(6)) + it, err := VectorU(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixU(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint, 0, 6)), WithShape(2, 3)) + it, err := MatrixU(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3U(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3U(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorU8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint8, 0, 6)), WithShape(6)) + it, err := VectorU8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixU8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint8, 0, 6)), WithShape(2, 3)) + it, err := MatrixU8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3U8(t *testing.T) { + assert := assert.New(t) + var T 
*Dense + T = New(WithBacking(Range(Uint8, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3U8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorU16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint16, 0, 6)), WithShape(6)) + it, err := VectorU16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixU16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint16, 0, 6)), WithShape(2, 3)) + it, err := MatrixU16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3U16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint16, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3U16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorU32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint32, 0, 6)), WithShape(6)) + it, err := VectorU32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixU32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint32, 0, 6)), WithShape(2, 3)) + it, err := MatrixU32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3U32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint32, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3U32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorU64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint64, 0, 6)), WithShape(6)) + it, err := VectorU64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixU64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint64, 0, 6)), WithShape(2, 3)) + it, err := MatrixU64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3U64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint64, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3U64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorF32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float32, 0, 6)), WithShape(6)) + it, err := VectorF32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixF32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float32, 0, 6)), WithShape(2, 3)) + it, err := MatrixF32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3F32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float32, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3F32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorF64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = 
New(WithBacking(Range(Float64, 0, 6)), WithShape(6)) + it, err := VectorF64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixF64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float64, 0, 6)), WithShape(2, 3)) + it, err := MatrixF64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3F64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float64, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3F64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorC64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex64, 0, 6)), WithShape(6)) + it, err := VectorC64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixC64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex64, 0, 6)), WithShape(2, 3)) + it, err := MatrixC64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3C64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex64, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3C64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorC128(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex128, 0, 6)), WithShape(6)) + it, err := VectorC128(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixC128(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex128, 0, 6)), WithShape(2, 3)) + it, err := MatrixC128(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3C128(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex128, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3C128(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorStr(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(String), WithShape(6)) + it, err := VectorStr(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixStr(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(String), WithShape(2, 3)) + it, err := MatrixStr(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3Str(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(String), WithShape(2, 3, 4)) + it, err := Tensor3Str(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} diff --git a/native/select_native.go b/native/select_native.go new file mode 100644 index 0000000..b048ae9 --- /dev/null +++ b/native/select_native.go @@ -0,0 +1,92 @@ +//go:build !purego +// +build !purego + +package native + +// Code generated by genlib2. DO NOT EDIT. + +import ( + _ "unsafe" + + "gorgonia.org/tensor" +) + +//go:linkname SelectB gorgonia.org/tensor.nativeSelectB + +// SelectB creates a slice of bools. See Example of NativeSelectF64. 
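+// The declarations below have no bodies: each is bound by the go:linkname
+// directive above it to an unexported implementation inside gorgonia.org/tensor.
+// A rough sketch of the axis semantics, as exercised by the tests in this patch
+// (shapes and names are illustrative only):
+//
+//     T := tensor.New(tensor.Of(tensor.Float64), tensor.WithShape(2, 3, 4, 5))
+//     rows, _ := native.SelectF64(T, 1) // len(rows) == 2*3 == 6; len(rows[0]) == 4*5 == 20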
+func SelectB(t *tensor.Dense, axis int) (retVal [][]bool, err error) + +//go:linkname SelectI gorgonia.org/tensor.nativeSelectI + +// SelectI creates a slice of ints. See Example of NativeSelectF64. +func SelectI(t *tensor.Dense, axis int) (retVal [][]int, err error) + +//go:linkname SelectI8 gorgonia.org/tensor.nativeSelectI8 + +// SelectI8 creates a slice of int8s. See Example of NativeSelectF64. +func SelectI8(t *tensor.Dense, axis int) (retVal [][]int8, err error) + +//go:linkname SelectI16 gorgonia.org/tensor.nativeSelectI16 + +// SelectI16 creates a slice of int16s. See Example of NativeSelectF64. +func SelectI16(t *tensor.Dense, axis int) (retVal [][]int16, err error) + +//go:linkname SelectI32 gorgonia.org/tensor.nativeSelectI32 + +// SelectI32 creates a slice of int32s. See Example of NativeSelectF64. +func SelectI32(t *tensor.Dense, axis int) (retVal [][]int32, err error) + +//go:linkname SelectI64 gorgonia.org/tensor.nativeSelectI64 + +// SelectI64 creates a slice of int64s. See Example of NativeSelectF64. +func SelectI64(t *tensor.Dense, axis int) (retVal [][]int64, err error) + +//go:linkname SelectU gorgonia.org/tensor.nativeSelectU + +// SelectU creates a slice of uints. See Example of NativeSelectF64. +func SelectU(t *tensor.Dense, axis int) (retVal [][]uint, err error) + +//go:linkname SelectU8 gorgonia.org/tensor.nativeSelectU8 + +// SelectU8 creates a slice of uint8s. See Example of NativeSelectF64. +func SelectU8(t *tensor.Dense, axis int) (retVal [][]uint8, err error) + +//go:linkname SelectU16 gorgonia.org/tensor.nativeSelectU16 + +// SelectU16 creates a slice of uint16s. See Example of NativeSelectF64. +func SelectU16(t *tensor.Dense, axis int) (retVal [][]uint16, err error) + +//go:linkname SelectU32 gorgonia.org/tensor.nativeSelectU32 + +// SelectU32 creates a slice of uint32s. See Example of NativeSelectF64. +func SelectU32(t *tensor.Dense, axis int) (retVal [][]uint32, err error) + +//go:linkname SelectU64 gorgonia.org/tensor.nativeSelectU64 + +// SelectU64 creates a slice of uint64s. See Example of NativeSelectF64. +func SelectU64(t *tensor.Dense, axis int) (retVal [][]uint64, err error) + +//go:linkname SelectF32 gorgonia.org/tensor.nativeSelectF32 + +// SelectF32 creates a slice of float32s. See Example of NativeSelectF64. +func SelectF32(t *tensor.Dense, axis int) (retVal [][]float32, err error) + +//go:linkname SelectF64 gorgonia.org/tensor.nativeSelectF64 + +// SelectF64 creates a slice of float64s. See Example of NativeSelectF64. +func SelectF64(t *tensor.Dense, axis int) (retVal [][]float64, err error) + +//go:linkname SelectC64 gorgonia.org/tensor.nativeSelectC64 + +// SelectC64 creates a slice of complex64s. See Example of NativeSelectF64. +func SelectC64(t *tensor.Dense, axis int) (retVal [][]complex64, err error) + +//go:linkname SelectC128 gorgonia.org/tensor.nativeSelectC128 + +// SelectC128 creates a slice of complex128s. See Example of NativeSelectF64. +func SelectC128(t *tensor.Dense, axis int) (retVal [][]complex128, err error) + +//go:linkname SelectStr gorgonia.org/tensor.nativeSelectStr + +// SelectStr creates a slice of strings. See Example of NativeSelectF64. +func SelectStr(t *tensor.Dense, axis int) (retVal [][]string, err error) diff --git a/native/select_native_purego.go b/native/select_native_purego.go new file mode 100644 index 0000000..6285fe0 --- /dev/null +++ b/native/select_native_purego.go @@ -0,0 +1,621 @@ +//go:build purego +// +build purego + +package native + +// Code generated by genlib2. DO NOT EDIT. 
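+// Under the purego build tag, the Select* functions below are implemented as
+// plain Go in this package; the default build instead declares the same
+// functions in select_native.go and binds them via go:linkname to the
+// unexported nativeSelect* implementations in gorgonia.org/tensor. A hedged
+// way to exercise this fallback (assuming a checkout of this module):
+//
+//     go test -tags=purego ./native/...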
+ +import ( + "reflect" + "unsafe" + + . "gorgonia.org/tensor" +) + +/* Native Select for bool */ + +// SelectB creates a slice of flat data types. See Example of NativeSelectF64. +func SelectB(t *Dense, axis int) (retVal [][]bool, err error) { + if err := checkNativeSelectable(t, axis, Bool); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]bool, 1) + retVal[0] = t.Bools() + case 2: + if axis == 0 { + return MatrixB(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Bools() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]bool, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]bool, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int */ + +// SelectI creates a slice of flat data types. See Example of NativeSelectF64. +func SelectI(t *Dense, axis int) (retVal [][]int, err error) { + if err := checkNativeSelectable(t, axis, Int); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int, 1) + retVal[0] = t.Ints() + case 2: + if axis == 0 { + return MatrixI(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Ints() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]int, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int8 */ + +// SelectI8 creates a slice of flat data types. See Example of NativeSelectF64. +func SelectI8(t *Dense, axis int) (retVal [][]int8, err error) { + if err := checkNativeSelectable(t, axis, Int8); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int8, 1) + retVal[0] = t.Int8s() + case 2: + if axis == 0 { + return MatrixI8(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Int8s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int8, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]int8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int16 */ + +// SelectI16 creates a slice of flat data types. See Example of NativeSelectF64. 
+func SelectI16(t *Dense, axis int) (retVal [][]int16, err error) { + if err := checkNativeSelectable(t, axis, Int16); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int16, 1) + retVal[0] = t.Int16s() + case 2: + if axis == 0 { + return MatrixI16(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Int16s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int16, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]int16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int32 */ + +// SelectI32 creates a slice of flat data types. See Example of NativeSelectF64. +func SelectI32(t *Dense, axis int) (retVal [][]int32, err error) { + if err := checkNativeSelectable(t, axis, Int32); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int32, 1) + retVal[0] = t.Int32s() + case 2: + if axis == 0 { + return MatrixI32(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Int32s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int32, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]int32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int64 */ + +// SelectI64 creates a slice of flat data types. See Example of NativeSelectF64. +func SelectI64(t *Dense, axis int) (retVal [][]int64, err error) { + if err := checkNativeSelectable(t, axis, Int64); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int64, 1) + retVal[0] = t.Int64s() + case 2: + if axis == 0 { + return MatrixI64(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Int64s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int64, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]int64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint */ + +// SelectU creates a slice of flat data types. See Example of NativeSelectF64. +func SelectU(t *Dense, axis int) (retVal [][]uint, err error) { + if err := checkNativeSelectable(t, axis, Uint); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint, 1) + retVal[0] = t.Uints() + case 2: + if axis == 0 { + return MatrixU(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uints() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]uint, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint8 */ + +// SelectU8 creates a slice of flat data types. See Example of NativeSelectF64. 
+func SelectU8(t *Dense, axis int) (retVal [][]uint8, err error) { + if err := checkNativeSelectable(t, axis, Uint8); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint8, 1) + retVal[0] = t.Uint8s() + case 2: + if axis == 0 { + return MatrixU8(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uint8s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint8, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]uint8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint16 */ + +// SelectU16 creates a slice of flat data types. See Example of NativeSelectF64. +func SelectU16(t *Dense, axis int) (retVal [][]uint16, err error) { + if err := checkNativeSelectable(t, axis, Uint16); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint16, 1) + retVal[0] = t.Uint16s() + case 2: + if axis == 0 { + return MatrixU16(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uint16s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint16, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]uint16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint32 */ + +// SelectU32 creates a slice of flat data types. See Example of NativeSelectF64. +func SelectU32(t *Dense, axis int) (retVal [][]uint32, err error) { + if err := checkNativeSelectable(t, axis, Uint32); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint32, 1) + retVal[0] = t.Uint32s() + case 2: + if axis == 0 { + return MatrixU32(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uint32s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint32, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]uint32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint64 */ + +// SelectU64 creates a slice of flat data types. See Example of NativeSelectF64. +func SelectU64(t *Dense, axis int) (retVal [][]uint64, err error) { + if err := checkNativeSelectable(t, axis, Uint64); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint64, 1) + retVal[0] = t.Uint64s() + case 2: + if axis == 0 { + return MatrixU64(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uint64s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint64, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]uint64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for float32 */ + +// SelectF32 creates a slice of flat data types. 
See Example of NativeSelectF64. +func SelectF32(t *Dense, axis int) (retVal [][]float32, err error) { + if err := checkNativeSelectable(t, axis, Float32); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]float32, 1) + retVal[0] = t.Float32s() + case 2: + if axis == 0 { + return MatrixF32(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Float32s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]float32, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]float32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for float64 */ + +// SelectF64 creates a slice of flat data types. See Example of NativeSelectF64. +func SelectF64(t *Dense, axis int) (retVal [][]float64, err error) { + if err := checkNativeSelectable(t, axis, Float64); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]float64, 1) + retVal[0] = t.Float64s() + case 2: + if axis == 0 { + return MatrixF64(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Float64s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]float64, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]float64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for complex64 */ + +// SelectC64 creates a slice of flat data types. See Example of NativeSelectF64. +func SelectC64(t *Dense, axis int) (retVal [][]complex64, err error) { + if err := checkNativeSelectable(t, axis, Complex64); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]complex64, 1) + retVal[0] = t.Complex64s() + case 2: + if axis == 0 { + return MatrixC64(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Complex64s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]complex64, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]complex64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for complex128 */ + +// SelectC128 creates a slice of flat data types. See Example of NativeSelectF64. 
+func SelectC128(t *Dense, axis int) (retVal [][]complex128, err error) { + if err := checkNativeSelectable(t, axis, Complex128); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]complex128, 1) + retVal[0] = t.Complex128s() + case 2: + if axis == 0 { + return MatrixC128(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Complex128s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]complex128, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]complex128, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for string */ + +// SelectStr creates a slice of flat data types. See Example of NativeSelectF64. +func SelectStr(t *Dense, axis int) (retVal [][]string, err error) { + if err := checkNativeSelectable(t, axis, String); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]string, 1) + retVal[0] = t.Strings() + case 2: + if axis == 0 { + return MatrixStr(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Strings() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]string, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]string, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} diff --git a/native/select_native_test.go b/native/select_native_test.go new file mode 100644 index 0000000..a6f247f --- /dev/null +++ b/native/select_native_test.go @@ -0,0 +1,842 @@ +package native + +// Code generated by genlib2. DO NOT EDIT. + +import ( + "testing" + + "github.com/stretchr/testify/assert" + . 
"gorgonia.org/tensor" +) + +func TestSelectB(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]bool + T = New(Of(Bool), WithShape(2, 3, 4, 5)) + if x, err = SelectB(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Bool), WithShape(2, 3, 4, 5)) + if x, err = SelectB(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Bool), WithShape(2, 3, 4, 5)) + if x, err = SelectB(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Bool), WithShape(2, 3)) + if x, err = SelectB(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Bool), WithShape(2, 3)) + if x, err = SelectB(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(false)) + if x, err = SelectB(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectB(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectI(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]int + T = New(Of(Int), WithShape(2, 3, 4, 5)) + if x, err = SelectI(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Int), WithShape(2, 3, 4, 5)) + if x, err = SelectI(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Int), WithShape(2, 3, 4, 5)) + if x, err = SelectI(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Int), WithShape(2, 3)) + if x, err = SelectI(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Int), WithShape(2, 3)) + if x, err = SelectI(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(int(0))) + if x, err = SelectI(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectI(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectI8(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]int8 + T = New(Of(Int8), WithShape(2, 3, 4, 5)) + if x, err = SelectI8(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Int8), WithShape(2, 3, 4, 5)) + if x, err = SelectI8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Int8), WithShape(2, 3, 4, 5)) + if x, err = SelectI8(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Int8), WithShape(2, 3)) + if x, err = SelectI8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Int8), WithShape(2, 3)) + if x, err = SelectI8(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(int8(0))) + if x, err = SelectI8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectI8(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectI16(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]int16 + T = 
New(Of(Int16), WithShape(2, 3, 4, 5)) + if x, err = SelectI16(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Int16), WithShape(2, 3, 4, 5)) + if x, err = SelectI16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Int16), WithShape(2, 3, 4, 5)) + if x, err = SelectI16(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Int16), WithShape(2, 3)) + if x, err = SelectI16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Int16), WithShape(2, 3)) + if x, err = SelectI16(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(int16(0))) + if x, err = SelectI16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectI16(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectI32(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]int32 + T = New(Of(Int32), WithShape(2, 3, 4, 5)) + if x, err = SelectI32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Int32), WithShape(2, 3, 4, 5)) + if x, err = SelectI32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Int32), WithShape(2, 3, 4, 5)) + if x, err = SelectI32(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Int32), WithShape(2, 3)) + if x, err = SelectI32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Int32), WithShape(2, 3)) + if x, err = SelectI32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(int32(0))) + if x, err = SelectI32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectI32(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectI64(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]int64 + T = New(Of(Int64), WithShape(2, 3, 4, 5)) + if x, err = SelectI64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Int64), WithShape(2, 3, 4, 5)) + if x, err = SelectI64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Int64), WithShape(2, 3, 4, 5)) + if x, err = SelectI64(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Int64), WithShape(2, 3)) + if x, err = SelectI64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Int64), WithShape(2, 3)) + if x, err = SelectI64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(int64(0))) + if x, err = SelectI64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectI64(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectU(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]uint + T = New(Of(Uint), WithShape(2, 3, 4, 5)) + if x, err = SelectU(T, 1); err != nil { + 
t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Uint), WithShape(2, 3, 4, 5)) + if x, err = SelectU(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Uint), WithShape(2, 3, 4, 5)) + if x, err = SelectU(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Uint), WithShape(2, 3)) + if x, err = SelectU(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Uint), WithShape(2, 3)) + if x, err = SelectU(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(uint(0))) + if x, err = SelectU(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectU(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectU8(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]uint8 + T = New(Of(Uint8), WithShape(2, 3, 4, 5)) + if x, err = SelectU8(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Uint8), WithShape(2, 3, 4, 5)) + if x, err = SelectU8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Uint8), WithShape(2, 3, 4, 5)) + if x, err = SelectU8(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Uint8), WithShape(2, 3)) + if x, err = SelectU8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Uint8), WithShape(2, 3)) + if x, err = SelectU8(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(uint8(0))) + if x, err = SelectU8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectU8(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectU16(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]uint16 + T = New(Of(Uint16), WithShape(2, 3, 4, 5)) + if x, err = SelectU16(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Uint16), WithShape(2, 3, 4, 5)) + if x, err = SelectU16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Uint16), WithShape(2, 3, 4, 5)) + if x, err = SelectU16(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Uint16), WithShape(2, 3)) + if x, err = SelectU16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Uint16), WithShape(2, 3)) + if x, err = SelectU16(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(uint16(0))) + if x, err = SelectU16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectU16(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectU32(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]uint32 + T = New(Of(Uint32), WithShape(2, 3, 4, 5)) + if x, err = SelectU32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Uint32), 
WithShape(2, 3, 4, 5)) + if x, err = SelectU32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Uint32), WithShape(2, 3, 4, 5)) + if x, err = SelectU32(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Uint32), WithShape(2, 3)) + if x, err = SelectU32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Uint32), WithShape(2, 3)) + if x, err = SelectU32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(uint32(0))) + if x, err = SelectU32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectU32(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectU64(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]uint64 + T = New(Of(Uint64), WithShape(2, 3, 4, 5)) + if x, err = SelectU64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Uint64), WithShape(2, 3, 4, 5)) + if x, err = SelectU64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Uint64), WithShape(2, 3, 4, 5)) + if x, err = SelectU64(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Uint64), WithShape(2, 3)) + if x, err = SelectU64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Uint64), WithShape(2, 3)) + if x, err = SelectU64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(uint64(0))) + if x, err = SelectU64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectU64(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectF32(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]float32 + T = New(Of(Float32), WithShape(2, 3, 4, 5)) + if x, err = SelectF32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Float32), WithShape(2, 3, 4, 5)) + if x, err = SelectF32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Float32), WithShape(2, 3, 4, 5)) + if x, err = SelectF32(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Float32), WithShape(2, 3)) + if x, err = SelectF32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Float32), WithShape(2, 3)) + if x, err = SelectF32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(float32(0))) + if x, err = SelectF32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectF32(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectF64(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]float64 + T = New(Of(Float64), WithShape(2, 3, 4, 5)) + if x, err = SelectF64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Float64), WithShape(2, 3, 4, 5)) + if x, err = SelectF64(T, 
0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Float64), WithShape(2, 3, 4, 5)) + if x, err = SelectF64(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Float64), WithShape(2, 3)) + if x, err = SelectF64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Float64), WithShape(2, 3)) + if x, err = SelectF64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(float64(0))) + if x, err = SelectF64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectF64(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectC64(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]complex64 + T = New(Of(Complex64), WithShape(2, 3, 4, 5)) + if x, err = SelectC64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Complex64), WithShape(2, 3, 4, 5)) + if x, err = SelectC64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Complex64), WithShape(2, 3, 4, 5)) + if x, err = SelectC64(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Complex64), WithShape(2, 3)) + if x, err = SelectC64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Complex64), WithShape(2, 3)) + if x, err = SelectC64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(complex64(0))) + if x, err = SelectC64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectC64(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectC128(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]complex128 + T = New(Of(Complex128), WithShape(2, 3, 4, 5)) + if x, err = SelectC128(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Complex128), WithShape(2, 3, 4, 5)) + if x, err = SelectC128(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Complex128), WithShape(2, 3, 4, 5)) + if x, err = SelectC128(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Complex128), WithShape(2, 3)) + if x, err = SelectC128(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Complex128), WithShape(2, 3)) + if x, err = SelectC128(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(complex128(0))) + if x, err = SelectC128(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectC128(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectStr(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]string + T = New(Of(String), WithShape(2, 3, 4, 5)) + if x, err = SelectStr(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(String), WithShape(2, 3, 4, 5)) + if x, err = 
SelectStr(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(String), WithShape(2, 3, 4, 5)) + if x, err = SelectStr(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(String), WithShape(2, 3)) + if x, err = SelectStr(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(String), WithShape(2, 3)) + if x, err = SelectStr(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar("")) + if x, err = SelectStr(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectStr(T, 10); err == nil { + t.Fatal("Expected errors") + } +} diff --git a/native/utils.go b/native/utils.go new file mode 100644 index 0000000..341388e --- /dev/null +++ b/native/utils.go @@ -0,0 +1,46 @@ +package native + +// Code generated by genlib2. DO NOT EDIT. + +import ( + "github.com/pkg/errors" + "gorgonia.org/dtype" + . "gorgonia.org/tensor" +) + +func checkNativeIterable(t *Dense, dims int, dt dtype.Dtype) error { + // checks: + if !t.IsNativelyAccessible() { + return errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible") + } + + if t.Shape().Dims() != dims { + return errors.Errorf("Cannot convert *Dense to native iterator. Expected number of dimension: %d, T has got %d dimensions (Shape: %v)", dims, t.Dims(), t.Shape()) + } + + if t.F() || t.RequiresIterator() { + return errors.Errorf("Not yet implemented: native matrix for colmajor or unpacked matrices") + } + + if t.Dtype() != dt { + return errors.Errorf("Conversion to native iterable only works on %v. Got %v", dt, t.Dtype()) + } + + return nil +} + +func checkNativeSelectable(t *Dense, axis int, dt dtype.Dtype) error { + if !t.IsNativelyAccessible() { + return errors.New("Cannot select on non-natively accessible data") + } + if axis >= t.Shape().Dims() && !(t.IsScalar() && axis == 0) { + return errors.Errorf("Cannot select on axis %d. Shape is %v", axis, t.Shape()) + } + if t.F() || t.RequiresIterator() { + return errors.Errorf("Not yet implemented: native select for colmajor or unpacked matrices") + } + if t.Dtype() != dt { + return errors.Errorf("Native selection only works on %v. 
Got %v", dt, t.Dtype()) + } + return nil +} diff --git a/optimizations_test.go b/optimizations_test.go index 18bb677..9b8afcb 100644 --- a/optimizations_test.go +++ b/optimizations_test.go @@ -1,15 +1,15 @@ -package tensor - -import ( - "testing" -) - -// this file contains tests to make sure certain algorithms/optimizations aren't crazy - -func TestRequiresIterator(t *testing.T) { - T := New(Of(Int), WithBacking([]int{1, 2, 3, 4})) - sliced, _ := T.Slice(makeRS(1, 3)) - if sliced.RequiresIterator() { - t.Errorf("Slicing on rows should not require Iterator") - } -} +package tensor + +import ( + "testing" +) + +// this file contains tests to make sure certain algorithms/optimizations aren't crazy + +func TestRequiresIterator(t *testing.T) { + T := New(Of(Int), WithBacking([]int{1, 2, 3, 4})) + sliced, _ := T.Slice(makeRS(1, 3)) + if sliced.RequiresIterator() { + t.Errorf("Slicing on rows should not require Iterator") + } +} diff --git a/perf.go b/perf.go index 573d8be..a37c610 100644 --- a/perf.go +++ b/perf.go @@ -4,6 +4,7 @@ import ( "runtime" "sync" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) @@ -56,9 +57,7 @@ func returnHeader(hdr *storage.Header) { } func destroyHeader(hdr *storage.Header) { - hdr.Ptr = nil - hdr.L = 0 - hdr.C = 0 + hdr.Raw = nil } var densePool = make(chan *Dense, PoolSize) @@ -83,25 +82,16 @@ func ReturnTensor(t Tensor) { } switch tt := t.(type) { case *Dense: - if tt.old != nil { - ReturnAP(tt.old) - tt.old = nil - } + tt.AP.zero() if tt.transposeWith != nil { ReturnInts(tt.transposeWith) tt.transposeWith = nil } - // return AP - ReturnAP(tt.AP) - // array reset - tt.t = Dtype{} - tt.array.Ptr = nil - tt.array.L = 0 - tt.array.C = 0 - tt.array.v = nil + tt.t = dtype.Dtype{} + tt.array.Header.Raw = nil // engine and flag reset tt.e = StdEng{} @@ -109,7 +99,7 @@ func ReturnTensor(t Tensor) { tt.flag = 0 // other reset - tt.old = nil + tt.old.zero() tt.viewOf = 0 tt.transposeWith = nil @@ -124,63 +114,14 @@ func ReturnTensor(t Tensor) { } } -/* AP POOL */ - -var apPool = make(chan *AP, PoolSize) - -func borrowAP() *AP { - select { - case ap := <-apPool: - return ap - default: - return new(AP) - } - // return apPool.Get().(*AP) -} - -// BorrowAP gets an AP from the pool. USE WITH CAUTION. -func BorrowAP(dims int) *AP { - ap := borrowAP() - ap.shape = BorrowInts(dims) - ap.strides = BorrowInts(dims) - ap.shape = ap.shape[:cap(ap.shape)] - ap.strides = ap.strides[:cap(ap.strides)] - return ap -} - -// ReturnAP returns the AP to the pool. USE WITH CAUTION. -func ReturnAP(ap *AP) { - ReturnInts([]int(ap.shape)) - ReturnInts(ap.strides) - ap.fin = false - - ap.o = 0 - ap.Δ = 0 - - if len(apPool) < cap(apPool) { - apPool <- ap - } - // apPool.Put(ap) -} - /* ---------------------------------------------------------------- ------------------ Create Pools ------------------------------------------------------------------*/ /* APLIST POOL */ -var apListPool [maxAPDims]sync.Pool - // Init function func init() { - for i := range apListPool { - size := i - apListPool[i].New = func() interface{} { return make([]*AP, size) } - } - - // for i := 0; i < PoolSize; i++ { - // intsPool <- make([]int, 8, 8) - // } for i := range intsPool { size := i @@ -222,11 +163,13 @@ func BorrowInts(size int) []int { if retVal == nil { return make([]int, size) } + // log.Printf("Borrowing %p. Called by %v", retVal, string(debug.Stack())) return retVal.([]int)[:size] } // ReturnInts returns a slice from the pool. USE WITH CAUTION. 
func ReturnInts(is []int) { + // log.Printf("Returning %p. Called by %v", is, string(debug.Stack())) if is == nil { return } @@ -293,43 +236,13 @@ func ReturnBools(is []bool) { // boolsPool[size].Put(is) } -// BorrowAPList gets an APList from the pool. USE WITH CAUTION. -func BorrowAPList(size int) []*AP { - if size >= 8 { - return make([]*AP, size) - } - - retVal := apListPool[size].Get() - if retVal == nil { - return make([]*AP, size) - } - return retVal.([]*AP) -} - -// ReturnAPList returns the APList to the pool. USE WITH CAUTION. -func ReturnAPList(aps []*AP) { - if aps == nil { - return - } - size := cap(aps) - if size >= 8 { - return - } - aps = aps[:cap(aps)] - for i := range aps { - aps[i] = nil - } - - apListPool[size].Put(aps) -} - // var optPool = make(chan *OpOpt, PoolSize) // var optPool = newRingbuffer(PoolSize) var optPool = &sync.Pool{ - New: func() interface{} { return new(OpOpt) }, + New: func() interface{} { return new(opOpt) }, } -func borrowOpOpt() *OpOpt { +func borrowOpOpt() *opOpt { // select { // case fo := <-optPool: // return fo @@ -337,7 +250,7 @@ func borrowOpOpt() *OpOpt { // return new(OpOpt) // } - return optPool.Get().(*OpOpt) + return optPool.Get().(*opOpt) // if fo, err := optPool.Get(); err == nil { // return (*OpOpt)(fo) @@ -345,12 +258,13 @@ func borrowOpOpt() *OpOpt { // return new(OpOpt) } -func returnOpOpt(oo *OpOpt) { +func returnOpOpt(oo *opOpt) { oo.reuse = nil oo.incr = nil oo.unsafe = false oo.same = false - oo.t = Dtype{} + oo.t = dtype.Dtype{} + oo.ctx = nil // if len(optPool) < cap(optPool) { // optPool <- oo // } diff --git a/scalar.go b/scalar.go new file mode 100644 index 0000000..ee37ba0 --- /dev/null +++ b/scalar.go @@ -0,0 +1,89 @@ +// +build ignore + +package tensor + +import ( + "fmt" + "io" + "reflect" + "unsafe" + + "gorgonia.org/dtype" + + "github.com/pkg/errors" + "gorgonia.org/tensor/internal/storage" +) + +var _ Tensor = Scalar{} +var _ ScalarRep = Scalar{} +var _ ScalarRep = ScalarDense{} + +// ScalarDense wraps a *Dense to provide a typesafe alternative for a scalar to be represented in a *Dense. +type ScalarDense struct { + *Dense +} + +func (s ScalarDense) IsScalar() bool { return true } + +func (s ScalarDense) ScalarValue() interface{} { return s.Dense.Data() } + +// Scalar is a representation of a scalar value on the CPU. 
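+// Note that this file is excluded from normal builds by the "ignore" build tag,
+// so the sketch below only illustrates the intended use of the type:
+//
+//     s := MakeScalar(3.14)
+//     fmt.Println(s.IsScalar())    // true
+//     fmt.Println(s.ScalarValue()) // 3.14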
+type Scalar struct{ v interface{} } + +func MakeScalar(v interface{}) Scalar { + if s, ok := v.(Scalar); ok { + return s + } + if s, ok := v.(*Scalar); ok { + return Scalar{s.v} + } + return Scalar{v} +} + +func (s Scalar) Shape() Shape { return ScalarShape() } +func (s Scalar) Strides() []int { return nil } +func (s Scalar) Dtype() dtype.Dtype { return dtype.Dtype{reflect.TypeOf(s.v)} } +func (s Scalar) Dims() int { return 0 } +func (s Scalar) Size() int { return 0 } // TODO +func (s Scalar) DataSize() int { return 0 } +func (s Scalar) RequiresIterator() bool { return false } +func (s Scalar) Iterator() Iterator { return nil } +func (s Scalar) DataOrder() DataOrder { return 0 } // TODO + +func (s Scalar) Slice(...Slice) (View, error) { return nil, errors.New("Cannot slice a scalar") } +func (s Scalar) At(at ...int) (interface{}, error) { return nil, errors.New("Get a value of a scalar") } +func (s Scalar) SetAt(_ interface{}, _ ...int) error { return errors.New("Cannot set value of scalar") } +func (s Scalar) Reshape(_ ...int) error { return errors.New("Cannot reshape a scalar") } +func (s Scalar) T(_ ...int) error { return errors.New("Cannot transpose a scalar") } +func (s Scalar) UT() {} +func (s Scalar) Transpose() error { return errors.New("Cannot transpose a scalar") } +func (s Scalar) Apply(fn interface{}, opts ...FuncOpt) (Tensor, error) { return nyierr(typeNYI, s) } + +func (s Scalar) Zero() {} //TODO +func (s Scalar) Memset(interface{}) error { return errors.New("Cannot Memset") } +func (s Scalar) Data() interface{} { return s.v } +func (s Scalar) Eq(other interface{}) bool { return s == other } +func (s Scalar) Clone() interface{} { return s } + +func (s Scalar) IsScalar() bool { return true } +func (s Scalar) ScalarValue() interface{} { return s.v } + +func (s Scalar) Engine() Engine { return nil } +func (s Scalar) MemSize() uintptr { return 0 } +func (s Scalar) Uintptr() uintptr { return 0 } +func (s Scalar) Pointer() unsafe.Pointer { return nil } +func (s Scalar) IsNativelyAccessible() bool { return true } +func (s Scalar) IsManuallyManaged() bool { return false } + +func (s Scalar) Format(t fmt.State, c rune) {} // TODO +func (s Scalar) String() string { return fmt.Sprintf("%v", s) } + +func (s Scalar) WriteNpy(io.Writer) error { return nyierr(typeNYI, s) } +func (s Scalar) ReadNpy(io.Reader) error { return nyierr(typeNYI, s) } +func (s Scalar) GobEncode() ([]byte, error) { return nil, nyierr(typeNYI, s) } +func (s Scalar) GobDecode([]byte) error { return nyierr(typeNYI, s) } + +func (s Scalar) standardEngine() StandardEngine { return StdEng{} } +func (s Scalar) hdr() *storage.Header { return nil } +func (s Scalar) arr() array { return array{} } +func (s Scalar) arrPtr() *array { return nil } diff --git a/select_native.go b/select_native.go new file mode 100644 index 0000000..d3cf1f2 --- /dev/null +++ b/select_native.go @@ -0,0 +1,635 @@ +package tensor + +import ( + "reflect" + "unsafe" + + "github.com/pkg/errors" + "gorgonia.org/dtype" +) + +// Code generated by genlib2. DO NOT EDIT. + +func checkNativeSelectable(t *Dense, axis int, dt dtype.Dtype) error { + if !t.IsNativelyAccessible() { + return errors.New("Cannot select on non-natively accessible data") + } + if axis >= t.Shape().Dims() && !(t.IsScalar() && axis == 0) { + return errors.Errorf("Cannot select on axis %d. 
Shape is %v", axis, t.Shape()) + } + if t.F() || t.RequiresIterator() { + return errors.Errorf("Not yet implemented: native select for colmajor or unpacked matrices") + } + if t.Dtype() != dt { + return errors.Errorf("Native selection only works on %v. Got %v", dt, t.Dtype()) + } + return nil +} + +/* Native Select for bool */ + +// nativeSelectB creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectB(t *Dense, axis int) (retVal [][]bool, err error) { + if err := checkNativeSelectable(t, axis, Bool); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]bool, 1) + retVal[0] = t.Bools() + case 2: + if axis == 0 { + return nativeDenseMatrixB(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Bools() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]bool, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]bool, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int */ + +// nativeSelectI creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectI(t *Dense, axis int) (retVal [][]int, err error) { + if err := checkNativeSelectable(t, axis, Int); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int, 1) + retVal[0] = t.Ints() + case 2: + if axis == 0 { + return nativeDenseMatrixI(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Ints() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]int, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int8 */ + +// nativeSelectI8 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectI8(t *Dense, axis int) (retVal [][]int8, err error) { + if err := checkNativeSelectable(t, axis, Int8); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int8, 1) + retVal[0] = t.Int8s() + case 2: + if axis == 0 { + return nativeDenseMatrixI8(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Int8s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int8, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]int8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int16 */ + +// nativeSelectI16 creates a slice of flat data types. See Example of NativeSelectF64. 
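+//
+// As with every nativeSelect* variant generated in this file, the inner
+// slices are constructed by pointing a reflect.SliceHeader at the tensor's
+// backing array, so each returned row aliases the *Dense data instead of
+// copying it: writes through a row are visible in the tensor, and the rows
+// remain valid only for as long as the tensor's data is not reallocated.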
+func nativeSelectI16(t *Dense, axis int) (retVal [][]int16, err error) { + if err := checkNativeSelectable(t, axis, Int16); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int16, 1) + retVal[0] = t.Int16s() + case 2: + if axis == 0 { + return nativeDenseMatrixI16(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Int16s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int16, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]int16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int32 */ + +// nativeSelectI32 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectI32(t *Dense, axis int) (retVal [][]int32, err error) { + if err := checkNativeSelectable(t, axis, Int32); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int32, 1) + retVal[0] = t.Int32s() + case 2: + if axis == 0 { + return nativeDenseMatrixI32(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Int32s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int32, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]int32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int64 */ + +// nativeSelectI64 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectI64(t *Dense, axis int) (retVal [][]int64, err error) { + if err := checkNativeSelectable(t, axis, Int64); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int64, 1) + retVal[0] = t.Int64s() + case 2: + if axis == 0 { + return nativeDenseMatrixI64(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Int64s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int64, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]int64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint */ + +// nativeSelectU creates a slice of flat data types. See Example of NativeSelectF64. 
+func nativeSelectU(t *Dense, axis int) (retVal [][]uint, err error) { + if err := checkNativeSelectable(t, axis, Uint); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint, 1) + retVal[0] = t.Uints() + case 2: + if axis == 0 { + return nativeDenseMatrixU(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uints() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]uint, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint8 */ + +// nativeSelectU8 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectU8(t *Dense, axis int) (retVal [][]uint8, err error) { + if err := checkNativeSelectable(t, axis, Uint8); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint8, 1) + retVal[0] = t.Uint8s() + case 2: + if axis == 0 { + return nativeDenseMatrixU8(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uint8s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint8, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]uint8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint16 */ + +// nativeSelectU16 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectU16(t *Dense, axis int) (retVal [][]uint16, err error) { + if err := checkNativeSelectable(t, axis, Uint16); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint16, 1) + retVal[0] = t.Uint16s() + case 2: + if axis == 0 { + return nativeDenseMatrixU16(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uint16s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint16, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]uint16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint32 */ + +// nativeSelectU32 creates a slice of flat data types. See Example of NativeSelectF64. 
+func nativeSelectU32(t *Dense, axis int) (retVal [][]uint32, err error) { + if err := checkNativeSelectable(t, axis, Uint32); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint32, 1) + retVal[0] = t.Uint32s() + case 2: + if axis == 0 { + return nativeDenseMatrixU32(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uint32s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint32, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]uint32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint64 */ + +// nativeSelectU64 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectU64(t *Dense, axis int) (retVal [][]uint64, err error) { + if err := checkNativeSelectable(t, axis, Uint64); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint64, 1) + retVal[0] = t.Uint64s() + case 2: + if axis == 0 { + return nativeDenseMatrixU64(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uint64s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint64, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]uint64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for float32 */ + +// nativeSelectF32 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectF32(t *Dense, axis int) (retVal [][]float32, err error) { + if err := checkNativeSelectable(t, axis, Float32); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]float32, 1) + retVal[0] = t.Float32s() + case 2: + if axis == 0 { + return nativeDenseMatrixF32(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Float32s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]float32, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]float32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for float64 */ + +// nativeSelectF64 creates a slice of flat data types. See Example of NativeSelectF64. 
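+//
+// An illustrative sketch of the intended use (the shape and expected sizes
+// mirror the generated tests in select_native_test.go; the exported
+// NativeSelectF64 wrapper these comments refer to is not part of this file,
+// so the unexported function is shown):
+//
+//	T := New(Of(Float64), WithShape(2, 3, 4, 5))
+//	rows, err := nativeSelectF64(T, 1)
+//	// rows holds 2*3 = 6 inner slices of 4*5 = 20 float64s each,
+//	// and every inner slice aliases T's underlying data.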
+func nativeSelectF64(t *Dense, axis int) (retVal [][]float64, err error) { + if err := checkNativeSelectable(t, axis, Float64); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]float64, 1) + retVal[0] = t.Float64s() + case 2: + if axis == 0 { + return nativeDenseMatrixF64(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Float64s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]float64, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]float64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for complex64 */ + +// nativeSelectC64 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectC64(t *Dense, axis int) (retVal [][]complex64, err error) { + if err := checkNativeSelectable(t, axis, Complex64); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]complex64, 1) + retVal[0] = t.Complex64s() + case 2: + if axis == 0 { + return nativeDenseMatrixC64(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Complex64s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]complex64, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]complex64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for complex128 */ + +// nativeSelectC128 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectC128(t *Dense, axis int) (retVal [][]complex128, err error) { + if err := checkNativeSelectable(t, axis, Complex128); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]complex128, 1) + retVal[0] = t.Complex128s() + case 2: + if axis == 0 { + return nativeDenseMatrixC128(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Complex128s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]complex128, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]complex128, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for string */ + +// nativeSelectStr creates a slice of flat data types. See Example of NativeSelectF64. 
+func nativeSelectStr(t *Dense, axis int) (retVal [][]string, err error) { + if err := checkNativeSelectable(t, axis, String); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]string, 1) + retVal[0] = t.Strings() + case 2: + if axis == 0 { + return nativeDenseMatrixStr(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Strings() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]string, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]string, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} diff --git a/select_native_batched.go b/select_native_batched.go new file mode 100644 index 0000000..c05bd76 --- /dev/null +++ b/select_native_batched.go @@ -0,0 +1,163 @@ +package tensor + +import ( + "reflect" + "runtime" + "unsafe" +) + +type BatchedNativeSelectF64 struct { + t *Dense + it [][]float64 // FUTURE: this can be made into generic in the future + + // state + + upper int // the outer dimension after being "reshaped" + limit int // limit as to how many rows the `it` can store + stride int // stride + r int // current row +} + +func BatchSelectF64(t *Dense, axis int, limit int) *BatchedNativeSelectF64 { + if err := checkNativeSelectable(t, axis, Float64); err != nil { + panic(err) + } + + if limit <= 0 { + limit = runtime.NumCPU() // default + } + upper := ProdInts(t.Shape()[:axis+1]) + if limit > upper { + limit = upper + // `it` should come from nativeSelectF64 + } + stride := t.Strides()[axis] + data := t.Float64s() + + it := make([][]float64, 0, limit) + var i, r int + for i, r = 0, 0; r < limit; i += stride { + // this block of code is basically + // it = append(it, data[i:i+stride]) + // TODO: benchmark + it = append(it, make([]float64, 0)) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&it[len(it)-1])) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + r++ + } + + return &BatchedNativeSelectF64{ + t: t, + it: it, + upper: upper, + limit: limit, + stride: stride, + r: r, + } +} + +func (it *BatchedNativeSelectF64) Start() (curBatch [][]float64, hasRemainingRows bool) { + if it.r != it.limit || it.IsTruncated() { + // then it's been moved, so we reset + it.Reset() + } + curBatch = it.it + hasRemainingRows = it.upper > it.r + return +} + +// Next moves the next batch into the native iterator. +func (it *BatchedNativeSelectF64) Next() (curBatch [][]float64, hasRemaingRows bool) { + var ( + i int // data ptr + r int // relative row / row counter for this batch + s int // absolute row + ) + if it.r == it.upper { + return it.it, false + } + data := it.t.Float64s() + + // this loop statement looks scary. But it isn't. Let me break it down: + // Initialization: + // i := it.r*it.stride // the data pointer is the row number * the stride of the matrix. + // r := 0 // loop counter. We're gonna iterate `it.limit` times. + // s := it.r // the current row number of the matrix. + // Condition (continue if the following are true): + // r < it.limit // we only want to iterate at most `it.limit` times. + // s < it.upper // we want to make sure we don't iterate more rows than there are rows in the matrix. + // Next: + // i = i + it.stride // we're ready to go to the next row. + // r = r+1 // we increment the row counter. + // s = s+1 // we increment the absolute row number. 
+ // + // Could this be written in a less concise way? Sure. But then there'd be a lot more places to keep track of things. + for i, r, s = it.r*it.stride, 0, it.r; r < it.limit && s < it.upper; i, r, s = i+it.stride, r+1, s+1 { + // the block of code below is basically: + // it.it[r] = data[i:i+stride] + // r++ + // For some reason when this is done, Go actually does a lot more allocations. + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&it.it[r])) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + } + it.r = s + + if it.r == it.upper && r < it.limit { + // truncate it.it because iterated rows is less than the limit. + // This implies that there are some extra rows. + it.it = it.it[:r] + } + + return it.it, true +} + +func (it *BatchedNativeSelectF64) Native() [][]float64 { return it.it } + +func (it *BatchedNativeSelectF64) Reset() { + it.it = it.it[:it.limit:it.limit] + + data := it.t.Float64s() + var i, r int + for i, r = 0, 0; r < it.limit; i += it.stride { + sl := it.it[r] + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&sl)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = it.stride + hdr.Cap = it.stride + it.it[r] = sl + r++ + } + it.r = r +} + +func (it *BatchedNativeSelectF64) IsTruncated() bool { return len(it.it) != it.limit } + +type IterSelect struct { + r int + upper int + stride int + total int +} + +func NewIterSelect(t *Dense, axis int) *IterSelect { + upper := ProdInts(t.Shape()[:axis+1]) + stride := t.Strides()[axis] + total := t.DataSize() + return &IterSelect{upper: upper, stride: stride, total: total} +} + +func (it *IterSelect) Start() (start, end int, hasRem bool) { + if it.r > it.stride { + it.Reset() + } + return it.r, it.stride, it.r*it.stride+it.stride < it.total +} + +func (it *IterSelect) Next() (start, end int, hasRem bool) { + it.r += it.stride + return it.r, it.r + it.stride, it.r+it.stride <= it.total +} + +func (it *IterSelect) Reset() { it.r = 0 } diff --git a/select_native_test.go b/select_native_test.go new file mode 100644 index 0000000..02291b5 --- /dev/null +++ b/select_native_test.go @@ -0,0 +1,841 @@ +package tensor + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// Code generated by genlib2. DO NOT EDIT. 
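+
+// A note on the expected lengths asserted throughout these tests: for a
+// tensor of shape s selected along axis a, nativeSelect* yields
+// prod(s[:a+1]) rows of prod(s[a+1:]) elements each. With shape (2, 3, 4, 5)
+// that works out to 2 rows of 60 for axis 0, 6 rows of 20 for axis 1, and
+// 120 rows of 1 for axis 3; an out-of-range axis (10 below) is rejected by
+// checkNativeSelectable.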
+ +func TestnativeSelectB(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]bool + T = New(Of(Bool), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectB(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Bool), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectB(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Bool), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectB(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Bool), WithShape(2, 3)) + if x, err = nativeSelectB(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Bool), WithShape(2, 3)) + if x, err = nativeSelectB(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(false)) + if x, err = nativeSelectB(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectB(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectI(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]int + T = New(Of(Int), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Int), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Int), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Int), WithShape(2, 3)) + if x, err = nativeSelectI(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Int), WithShape(2, 3)) + if x, err = nativeSelectI(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(int(0))) + if x, err = nativeSelectI(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectI(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectI8(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]int8 + T = New(Of(Int8), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI8(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Int8), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Int8), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI8(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Int8), WithShape(2, 3)) + if x, err = nativeSelectI8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Int8), WithShape(2, 3)) + if x, err = nativeSelectI8(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(int8(0))) + if x, err = nativeSelectI8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectI8(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + 
+func TestnativeSelectI16(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]int16 + T = New(Of(Int16), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI16(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Int16), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Int16), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI16(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Int16), WithShape(2, 3)) + if x, err = nativeSelectI16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Int16), WithShape(2, 3)) + if x, err = nativeSelectI16(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(int16(0))) + if x, err = nativeSelectI16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectI16(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectI32(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]int32 + T = New(Of(Int32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Int32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Int32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI32(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Int32), WithShape(2, 3)) + if x, err = nativeSelectI32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Int32), WithShape(2, 3)) + if x, err = nativeSelectI32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(int32(0))) + if x, err = nativeSelectI32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectI32(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectI64(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]int64 + T = New(Of(Int64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Int64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Int64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI64(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Int64), WithShape(2, 3)) + if x, err = nativeSelectI64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Int64), WithShape(2, 3)) + if x, err = nativeSelectI64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(int64(0))) + if x, err = nativeSelectI64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = 
nativeSelectI64(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectU(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]uint + T = New(Of(Uint), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Uint), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Uint), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Uint), WithShape(2, 3)) + if x, err = nativeSelectU(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Uint), WithShape(2, 3)) + if x, err = nativeSelectU(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(uint(0))) + if x, err = nativeSelectU(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectU(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectU8(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]uint8 + T = New(Of(Uint8), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU8(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Uint8), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Uint8), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU8(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Uint8), WithShape(2, 3)) + if x, err = nativeSelectU8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Uint8), WithShape(2, 3)) + if x, err = nativeSelectU8(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(uint8(0))) + if x, err = nativeSelectU8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectU8(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectU16(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]uint16 + T = New(Of(Uint16), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU16(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Uint16), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Uint16), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU16(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Uint16), WithShape(2, 3)) + if x, err = nativeSelectU16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Uint16), WithShape(2, 3)) + if x, err = nativeSelectU16(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(uint16(0))) + if x, err = nativeSelectU16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + 
assert.Equal(1, len(x[0])) + + if _, err = nativeSelectU16(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectU32(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]uint32 + T = New(Of(Uint32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Uint32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Uint32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU32(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Uint32), WithShape(2, 3)) + if x, err = nativeSelectU32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Uint32), WithShape(2, 3)) + if x, err = nativeSelectU32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(uint32(0))) + if x, err = nativeSelectU32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectU32(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectU64(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]uint64 + T = New(Of(Uint64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Uint64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Uint64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU64(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Uint64), WithShape(2, 3)) + if x, err = nativeSelectU64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Uint64), WithShape(2, 3)) + if x, err = nativeSelectU64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(uint64(0))) + if x, err = nativeSelectU64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectU64(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectF32(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]float32 + T = New(Of(Float32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectF32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Float32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectF32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Float32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectF32(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Float32), WithShape(2, 3)) + if x, err = nativeSelectF32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Float32), WithShape(2, 3)) + if x, err = nativeSelectF32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(float32(0))) + if 
x, err = nativeSelectF32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectF32(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectF64(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]float64 + T = New(Of(Float64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectF64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Float64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectF64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Float64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectF64(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Float64), WithShape(2, 3)) + if x, err = nativeSelectF64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Float64), WithShape(2, 3)) + if x, err = nativeSelectF64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(float64(0))) + if x, err = nativeSelectF64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectF64(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectC64(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]complex64 + T = New(Of(Complex64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectC64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Complex64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectC64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Complex64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectC64(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Complex64), WithShape(2, 3)) + if x, err = nativeSelectC64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Complex64), WithShape(2, 3)) + if x, err = nativeSelectC64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(complex64(0))) + if x, err = nativeSelectC64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectC64(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectC128(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]complex128 + T = New(Of(Complex128), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectC128(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Complex128), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectC128(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Complex128), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectC128(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Complex128), WithShape(2, 3)) + if x, err = nativeSelectC128(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Complex128), WithShape(2, 3)) + if x, err = 
nativeSelectC128(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(complex128(0))) + if x, err = nativeSelectC128(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectC128(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectStr(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]string + T = New(Of(String), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectStr(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(String), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectStr(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(String), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectStr(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(String), WithShape(2, 3)) + if x, err = nativeSelectStr(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(String), WithShape(2, 3)) + if x, err = nativeSelectStr(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar("")) + if x, err = nativeSelectStr(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectStr(T, 10); err == nil { + t.Fatal("Expected errors") + } +} diff --git a/shape.go b/shape.go index ba0b18f..f8d5d0a 100644 --- a/shape.go +++ b/shape.go @@ -1,9 +1,7 @@ package tensor import ( - "fmt" - - "github.com/pkg/errors" + "gorgonia.org/shapes" ) var scalarShape = Shape{} @@ -11,30 +9,21 @@ var scalarShape = Shape{} // ScalarShape represents a scalar. It has no dimensions, no sizes func ScalarShape() Shape { return scalarShape } -// Shape represents the dimensions of a Tensor. A (2,3) matrix has a shape of (2,3) - 2 rows, 3 columns. -// Likewise, a shape of (2,3,4) means a Tensor has 3 dimensions: 2 layers, 3 rows, 4 columns. -// -// Vectors are of particular note. This package defines a shape of (x, 1) as a column vector and -// a (1, x) as a row vector. Row vectors and column vectors are matrices as well. It is important to note that -// row and column vectors and vanilla vectors are comparable under some circumstances -type Shape []int - -// TotalSize returns the number of elements expected in a Tensor of a certain shape -func (s Shape) TotalSize() int { - return ProdInts([]int(s)) -} +// Shape represents a Shape. See the package shapes +type Shape = shapes.Shape -func (s Shape) calcStrides() []int { +// CalcStrides calculates the default strides for a shape +func CalcStrides(s Shape) []int { if s.IsScalar() { return nil } retVal := BorrowInts(len(s)) - if s.IsVector() { - retVal[0] = 1 - retVal = retVal[:1] - return retVal - } + // if s.IsVector() { + // retVal[0] = 1 + // retVal = retVal[:1] + // return retVal + // } acc := 1 for i := len(s) - 1; i >= 0; i-- { @@ -48,10 +37,10 @@ func (s Shape) calcStrides() []int { return retVal } -// calcStridesWithMask is similar to calcStrides, except that it has an argument, masks. It is used to mask out given dimensions +// CalcStridesWithMask is similar to CalcStrides, except that it has an argument, masks. 
It is used to mask out given dimensions // during calculation of stride -func (s Shape) calcStridesWithMask(mask []bool) []int { - if s.IsScalar() { +func CalcStridesWithMask(s Shape, mask []bool) []int { + if s.IsScalarEquiv() { return nil } @@ -84,8 +73,9 @@ func (s Shape) calcStridesWithMask(mask []bool) []int { return retVal } -func (s Shape) calcStridesColMajor() []int { - if s.IsScalar() { +// CalcStridesColMajor is like CalcStrides, but assumes a col major layout +func CalcStridesColMajor(s Shape) []int { + if s.IsScalarEquiv() { return nil } @@ -108,246 +98,37 @@ func (s Shape) calcStridesColMajor() []int { return retVal } -// Eq indicates if a shape is equal with another. There is a soft concept of equality when it comes to vectors. +// asMat returns a matrix shape from the given shape and axis. The given axis is which dim it will stop in. // -// If s is a column vector and other is a vanilla vector, they're considered equal if the size of the column dimension is the same as the vector size; -// if s is a row vector and other is a vanilla vector, they're considered equal if the size of the row dimension is the same as the vector size -func (s Shape) Eq(other Shape) bool { - if s.IsScalar() && other.IsScalar() { - return true - } - - if s.IsVector() && other.IsVector() { - switch { - case len(s) == 2 && len(other) == 1: - if (s.IsColVec() && s[0] == other[0]) || (s.IsRowVec() && s[1] == other[0]) { - return true - } - return false - case len(s) == 1 && len(other) == 2: - if (other.IsColVec() && other[0] == s[0]) || (other.IsRowVec() && other[1] == s[0]) { - return true - } - return false - } - } - - if len(s) != len(other) { - return false - } - - for i, v := range s { - if other[i] != v { - return false - } - } - return true -} - -// Clone clones a shape. -func (s Shape) Clone() Shape { - retVal := BorrowInts(len(s)) - copy(retVal, s) - return retVal -} - -// IsScalar returns true if the access pattern indicates it's a scalar value -func (s Shape) IsScalar() bool { return len(s) == 0 || (len(s) == 1 && s[0] == 1) } - -// IsVector returns whether the access pattern falls into one of three possible definitions of vectors: -// vanilla vector (not a row or a col) -// column vector -// row vector -func (s Shape) IsVector() bool { return s.IsColVec() || s.IsRowVec() || (len(s) == 1 && s[0] > 1) } - -// IsColVec returns true when the access pattern has the shape (x, 1) -func (s Shape) IsColVec() bool { return len(s) == 2 && (s[1] == 1 && s[0] > 1) } - -// IsRowVec returns true when the access pattern has the shape (1, x) -func (s Shape) IsRowVec() bool { return len(s) == 2 && (s[0] == 1 && s[1] > 1) } - -// IsMatrix returns true if it's a matrix. This is mostly a convenience method. RowVec and ColVecs are also considered matrices -func (s Shape) IsMatrix() bool { return len(s) == 2 } - -// Dims returns the number of dimensions in the shape -func (s Shape) Dims() int { return len(s) } - -func (s Shape) DimSize(d int) (size int, err error) { - if (s.IsScalar() && d != 0) || (!s.IsScalar() && d >= len(s)) { - err = errors.Errorf(dimMismatch, len(s), d) - return - } - - switch { - case s.IsScalar(): - return 0, nil - default: - return s[d], nil - } -} - -// S gives the new shape after a shape has been sliced. 
It's repeated from the AP S() method mainly because there are other functions in Gorgonia that uses only shape -func (s Shape) S(slices ...Slice) (retVal Shape, err error) { - opDims := len(s) - if len(slices) > opDims { - err = errors.Errorf(dimMismatch, opDims, len(slices)) - return - } - - retVal = s.Clone() - - for d, size := range s { - var sl Slice // default is a nil Slice - if d <= len(slices)-1 { - sl = slices[d] - } - - var start, end, step int - if start, end, step, err = SliceDetails(sl, size); err != nil { - return - } - - if step > 0 { - retVal[d] = (end - start) / step - - //fix - if retVal[d] <= 0 { - retVal[d] = 1 - } - } else { - retVal[d] = (end - start) - } - - } - - // drop any dimension with size 1, except the last dimension - dims := s.Dims() - for d := 0; d < dims; d++ { - if retVal[d] == 1 /*&& d != t.dims-1 && dims > 2*/ { - retVal = append(retVal[:d], retVal[d+1:]...) - d-- - dims-- - } - } - - if retVal.IsScalar() { - ReturnInts(retVal) - return ScalarShape(), nil - } - - return -} - -// Repeat returns the expected new shape given the repetition parameters. -func (s Shape) Repeat(axis int, repeats ...int) (newShape Shape, finalRepeats []int, size int, err error) { +// asMat((5), 0, true) = (1, 5) +// asMat((5), 1, true) = (5, 1) +// asMat((3,4,5), 0, true) = (1, 60) +// asMat((3,4,5), 1, true) = (3, 20) +// asMat((3,4,5), 2, true) = (12, 5) +// asMat((3,4,5), 0, false) = (1, 20) +// asMat((3,4,5), 1, false) = (3, 5) +// asMat((3,4,5), 2, false) = (12, 1) +func asMat(a Shape, axis int, inclusive bool) (retVal Shape) { + // no need to do a check because asMat will only ever be used by internal functions. + + retVal = Shape(BorrowInts(2)) switch { - case axis == AllAxes: - size = s.TotalSize() - newShape = Shape{size} - axis = 0 - case s.IsScalar(): - size = 1 - // special case for row vecs - if axis == 1 { - newShape = Shape{1, 0} - } else { - // otherwise it will be repeated into a vanilla vector - newShape = Shape{0} - } - case s.IsVector() && !s.IsRowVec() && !s.IsColVec() && axis == 1: - size = 1 - newShape = s.Clone() - newShape = append(newShape, 1) - default: - if axis >= len(s) { - // error - err = errors.Errorf(invalidAxis, axis, s.Dims()) - return - } - size = s[axis] - newShape = s.Clone() - } - - // special case to allow generic repeats - if len(repeats) == 1 { - rep := repeats[0] - repeats = make([]int, size) - for i := range repeats { - repeats[i] = rep - } - } - reps := len(repeats) - if reps != size { - err = errors.Errorf(broadcastError, size, reps) - return - } - - newSize := SumInts(repeats) - newShape[axis] = newSize - finalRepeats = repeats - return -} - -// Concat returns the expected new shape given the concatenation parameters -func (s Shape) Concat(axis int, ss ...Shape) (newShape Shape, err error) { - dims := s.Dims() - - // check that all the concatenates have the same dimensions - for _, shp := range ss { - if shp.Dims() != dims { - err = errors.Errorf(dimMismatch, dims, shp.Dims()) - return - } - } - - // special case - if axis == AllAxes { - axis = 0 - } - - // nope... no negative indexing here. 
- if axis < 0 { - err = errors.Errorf(invalidAxis, axis, len(s)) + case a.Dims() == 1 && axis == 0: + retVal[0] = 1 + retVal[1] = a[0] return - } - - if axis >= dims { - err = errors.Errorf(invalidAxis, axis, len(s)) + case a.Dims() == 1 && axis == 1: + retVal[0] = a[0] + retVal[1] = 1 return } - - newShape = Shape(BorrowInts(dims)) - copy(newShape, s) - - for _, shp := range ss { - for d := 0; d < dims; d++ { - if d == axis { - newShape[d] += shp[d] - } else { - // validate that the rest of the dimensions match up - if newShape[d] != shp[d] { - err = errors.Errorf(dimMismatch, newShape[d], shp[d]) - return - } - } - } + // outer + retVal[0] = ProdInts(a[:axis]) + aplus := axis + if !inclusive { + aplus++ } + // inner + retVal[1] = ProdInts(a[aplus:]) return } - -// Format implements fmt.Formatter, and formats a shape nicely -func (s Shape) Format(st fmt.State, r rune) { - switch r { - case 'v', 's': - st.Write([]byte("(")) - for i, v := range s { - fmt.Fprintf(st, "%d", v) - if i < len(s)-1 { - st.Write([]byte(", ")) - } - } - st.Write([]byte(")")) - default: - fmt.Fprintf(st, "%v", []int(s)) - } -} diff --git a/shape_test.go b/shape_test.go index 51fe64a..9433ba9 100644 --- a/shape_test.go +++ b/shape_test.go @@ -1,295 +1,47 @@ package tensor import ( - "fmt" "testing" "github.com/stretchr/testify/assert" ) -func TestShapeBasics(t *testing.T) { - var s Shape - var ds int - var err error - s = Shape{1, 2} - - if ds, err = s.DimSize(0); err != nil { - t.Error(err) - } - if ds != 1 { - t.Error("Expected DimSize(0) to be 1") - } - - if ds, err = s.DimSize(2); err == nil { - t.Error("Expected a DimensionMismatch error") - } - - s = ScalarShape() - if ds, err = s.DimSize(0); err != nil { - t.Error(err) - } - - if ds != 0 { - t.Error("Expected DimSize(0) of a scalar to be 0") - } - - // format for completeness sake - s = Shape{2, 1} - if fmt.Sprintf("%d", s) != "[2 1]" { - t.Error("Shape.Format() error") - } -} - -func TestShapeIsX(t *testing.T) { - assert := assert.New(t) - var s Shape - - // scalar shape - s = Shape{} - assert.True(s.IsScalar()) - assert.False(s.IsVector()) - assert.False(s.IsColVec()) - assert.False(s.IsRowVec()) - - s = Shape{1} - assert.True(s.IsScalar()) - assert.False(s.IsVector()) - assert.False(s.IsColVec()) - assert.False(s.IsRowVec()) - - // vector - s = Shape{2} - assert.False(s.IsScalar()) - assert.True(s.IsVector()) - assert.False(s.IsColVec()) - assert.False(s.IsRowVec()) - - s = Shape{2, 1} - assert.False(s.IsScalar()) - assert.True(s.IsVector()) - assert.True(s.IsColVec()) - assert.False(s.IsRowVec()) - - s = Shape{1, 2} - assert.False(s.IsScalar()) - assert.True(s.IsVector()) - assert.False(s.IsColVec()) - assert.True(s.IsRowVec()) - - // matrix and up - s = Shape{2, 2} - assert.False(s.IsScalar()) - assert.False(s.IsVector()) - assert.False(s.IsColVec()) - assert.False(s.IsRowVec()) -} - func TestShapeCalcStride(t *testing.T) { assert := assert.New(t) var s Shape // scalar shape s = Shape{} - assert.Nil(s.calcStrides()) + assert.Nil(CalcStrides(s)) + // vector shape s = Shape{1} - assert.Nil(s.calcStrides()) + assert.Equal([]int{1}, CalcStrides(s)) - // vector shape s = Shape{2, 1} - assert.Equal([]int{1}, s.calcStrides()) + assert.Equal([]int{1, 1}, CalcStrides(s)) s = Shape{1, 2} - assert.Equal([]int{1}, s.calcStrides()) + assert.Equal([]int{2, 1}, CalcStrides(s)) s = Shape{2} - assert.Equal([]int{1}, s.calcStrides()) + assert.Equal([]int{1}, CalcStrides(s)) // matrix strides s = Shape{2, 2} - assert.Equal([]int{2, 1}, s.calcStrides()) + 
assert.Equal([]int{2, 1}, CalcStrides(s)) s = Shape{5, 2} - assert.Equal([]int{2, 1}, s.calcStrides()) + assert.Equal([]int{2, 1}, CalcStrides(s)) // 3D strides s = Shape{2, 3, 4} - assert.Equal([]int{12, 4, 1}, s.calcStrides()) + assert.Equal([]int{12, 4, 1}, CalcStrides(s)) // stupid shape s = Shape{-2, 1, 2} fail := func() { - s.calcStrides() + CalcStrides(s) } assert.Panics(fail) } - -func TestShapeEquality(t *testing.T) { - assert := assert.New(t) - var s1, s2 Shape - - // scalar - s1 = Shape{1} - s2 = Shape{} - assert.True(s1.Eq(s2)) - assert.True(s2.Eq(s1)) - - // vector - s1 = Shape{3} - s2 = Shape{5} - assert.False(s1.Eq(s2)) - assert.False(s2.Eq(s1)) - - s1 = Shape{2, 1} - s2 = Shape{2, 1} - assert.True(s1.Eq(s2)) - assert.True(s2.Eq(s1)) - - s2 = Shape{2} - assert.True(s1.Eq(s2)) - assert.True(s2.Eq(s1)) - - s2 = Shape{1, 2} - assert.False(s1.Eq(s2)) - assert.False(s2.Eq(s1)) - - s1 = Shape{2} - assert.True(s1.Eq(s2)) - assert.True(s2.Eq(s1)) - - s2 = Shape{2, 3} - assert.False(s1.Eq(s2)) - assert.False(s2.Eq(s1)) - - // matrix - s1 = Shape{2, 3} - assert.True(s1.Eq(s2)) - assert.True(s2.Eq(s1)) - - s2 = Shape{3, 2} - assert.False(s1.Eq(s2)) - assert.False(s2.Eq(s1)) - - // just for that green coloured code - s1 = Shape{2} - s2 = Shape{1, 3} - assert.False(s1.Eq(s2)) - assert.False(s2.Eq(s1)) -} - -var shapeSliceTests = []struct { - name string - s Shape - sli []Slice - - expected Shape - err bool -}{ - {"slicing a scalar shape", ScalarShape(), nil, ScalarShape(), false}, - {"slicing a scalar shape", ScalarShape(), []Slice{rs{0, 0, 0}}, nil, true}, - {"vec[0]", Shape{2}, []Slice{rs{0, 1, 0}}, ScalarShape(), false}, - {"vec[3]", Shape{2}, []Slice{rs{3, 4, 0}}, nil, true}, - {"vec[:, 0]", Shape{2}, []Slice{nil, rs{0, 1, 0}}, nil, true}, - {"vec[1:4:2]", Shape{5}, []Slice{rs{1, 4, 2}}, ScalarShape(), false}, -} - -func TestShape_Slice(t *testing.T) { - for i, ssts := range shapeSliceTests { - newShape, err := ssts.s.S(ssts.sli...) - if checkErr(t, ssts.err, err, "Shape slice", i) { - continue - } - - if !ssts.expected.Eq(newShape) { - t.Errorf("Test %q: Expected shape %v. 
Got %v instead", ssts.name, ssts.expected, newShape) - } - } -} - -var shapeRepeatTests = []struct { - name string - s Shape - repeats []int - axis int - - expected Shape - expectedRepeats []int - expectedSize int - err bool -}{ - {"scalar repeat on axis 0", ScalarShape(), []int{3}, 0, Shape{3}, []int{3}, 1, false}, - {"scalar repeat on axis 1", ScalarShape(), []int{3}, 1, Shape{1, 3}, []int{3}, 1, false}, - {"vector repeat on axis 0", Shape{2}, []int{3}, 0, Shape{6}, []int{3, 3}, 2, false}, - {"vector repeat on axis 1", Shape{2}, []int{3}, 1, Shape{2, 3}, []int{3}, 1, false}, - {"colvec repeats on axis 0", Shape{2, 1}, []int{3}, 0, Shape{6, 1}, []int{3, 3}, 2, false}, - {"colvec repeats on axis 1", Shape{2, 1}, []int{3}, 1, Shape{2, 3}, []int{3}, 1, false}, - {"rowvec repeats on axis 0", Shape{1, 2}, []int{3}, 0, Shape{3, 2}, []int{3}, 1, false}, - {"rowvec repeats on axis 1", Shape{1, 2}, []int{3}, 1, Shape{1, 6}, []int{3, 3}, 2, false}, - {"3-Tensor repeats", Shape{2, 3, 2}, []int{1, 2, 1}, 1, Shape{2, 4, 2}, []int{1, 2, 1}, 3, false}, - {"3-Tensor generic repeats", Shape{2, 3, 2}, []int{2}, AllAxes, Shape{24}, []int{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}, 12, false}, - {"3-Tensor generic repeat, axis specified", Shape{2, 3, 2}, []int{2}, 2, Shape{2, 3, 4}, []int{2, 2}, 2, false}, - - // stupids - {"nonexisting axis 2", Shape{2, 1}, []int{3}, 2, nil, nil, 0, true}, - {"mismatching repeats", Shape{2, 3, 2}, []int{3, 1, 2}, 0, nil, nil, 0, true}, -} - -func TestShape_Repeat(t *testing.T) { - assert := assert.New(t) - for _, srts := range shapeRepeatTests { - newShape, reps, size, err := srts.s.Repeat(srts.axis, srts.repeats...) - - switch { - case srts.err: - if err == nil { - t.Error("Expected an error") - } - continue - case !srts.err && err != nil: - t.Error(err) - continue - } - - assert.True(srts.expected.Eq(newShape), "Test %q: Want: %v. Got %v", srts.name, srts.expected, newShape) - assert.Equal(srts.expectedRepeats, reps, "Test %q: ", srts.name) - assert.Equal(srts.expectedSize, size, "Test %q: ", srts.name) - } -} - -var shapeConcatTests = []struct { - name string - s Shape - axis int - ss []Shape - - expected Shape - err bool -}{ - {"standard, axis 0 ", Shape{2, 2}, 0, []Shape{{2, 2}, {2, 2}}, Shape{6, 2}, false}, - {"standard, axis 1 ", Shape{2, 2}, 1, []Shape{{2, 2}, {2, 2}}, Shape{2, 6}, false}, - {"standard, axis AllAxes ", Shape{2, 2}, -1, []Shape{{2, 2}, {2, 2}}, Shape{6, 2}, false}, - {"concat to empty", Shape{2}, 0, nil, Shape{2}, false}, - - {"stupids: different dims", Shape{2, 2}, 0, []Shape{{2, 3, 2}}, nil, true}, - {"stupids: negative axes", Shape{2, 2}, -5, []Shape{{2, 2}}, nil, true}, - {"stupids: toobig axis", Shape{2, 2}, 5, []Shape{{2, 2}}, nil, true}, - {"subtle stupids: dim mismatch", Shape{2, 2}, 0, []Shape{{2, 2}, {2, 3}}, nil, true}, -} - -func TestShape_Concat(t *testing.T) { - assert := assert.New(t) - for _, scts := range shapeConcatTests { - newShape, err := scts.s.Concat(scts.axis, scts.ss...) - switch { - case scts.err: - if err == nil { - t.Error("Expected an error") - } - continue - case !scts.err && err != nil: - t.Error(err) - continue - } - assert.Equal(scts.expected, newShape) - } -} diff --git a/slice.go b/slice.go index ecba60d..7ee3522 100644 --- a/slice.go +++ b/slice.go @@ -1,11 +1,13 @@ package tensor -// A Slice represents a slicing operation for a Tensor. 
-type Slice interface { - Start() int - End() int - Step() int -} +import ( + "gorgonia.org/shapes" +) + +var xxx Slice = ss(1) +var _ shapes.Slice = xxx + +type Slice = shapes.Slice type rs struct { start, end, step int @@ -34,3 +36,39 @@ type ss int func (s ss) Start() int { return int(s) } func (s ss) End() int { return int(s) + 1 } func (s ss) Step() int { return 0 } + +// sli is slice. It's named sli to prevent confusion over naming +type sli struct { + start, end, step int +} + +// S creates a Slice. +// end is optional. It should be passed in as the first param of the optionals. +// step is optional. It should be passed in as the second param of the optionals. +// +// Default end is start+1. Default step is 1, unless end == step+1, then it defaults to 0 +func S(start int, opt ...int) Slice { + var end, step int + if len(opt) > 0 { + end = opt[0] + } else { + end = start + 1 + } + + step = 1 + if len(opt) > 1 { + step = opt[1] + } else if end == start+1 { + step = 0 + } + + return &sli{ + start: start, + end: end, + step: step, + } +} + +func (s *sli) Start() int { return s.start } +func (s *sli) End() int { return s.end } +func (s *sli) Step() int { return s.step } diff --git a/sparse.go b/sparse.go index abb36c1..b500db8 100644 --- a/sparse.go +++ b/sparse.go @@ -2,11 +2,11 @@ package tensor import ( "reflect" - "unsafe" "sort" "github.com/pkg/errors" + "gorgonia.org/dtype" ) var ( @@ -29,9 +29,9 @@ type coo struct { data array } -func (c *coo) Len() int { return c.data.L } +func (c *coo) Len() int { return c.data.Len() } func (c *coo) Less(i, j int) bool { - if c.o.isColMajor() { + if c.o.IsColMajor() { return c.colMajorLess(i, j) } return c.rowMajorLess(i, j) @@ -182,13 +182,14 @@ func CSCFromCoord(shape Shape, xs, ys []int, data interface{}) *CS { return t } -func (t *CS) Shape() Shape { return t.s } -func (t *CS) Strides() []int { return nil } -func (t *CS) Dtype() Dtype { return t.t } -func (t *CS) Dims() int { return 2 } -func (t *CS) Size() int { return t.s.TotalSize() } -func (t *CS) DataSize() int { return t.L } -func (t *CS) Engine() Engine { return t.e } +func (t *CS) Shape() Shape { return t.s } +func (t *CS) Strides() []int { return nil } +func (t *CS) Dtype() dtype.Dtype { return t.t } +func (t *CS) Dims() int { return 2 } +func (t *CS) Size() int { return t.s.TotalSize() } +func (t *CS) DataSize() int { return t.Len() } +func (t *CS) Engine() Engine { return t.e } +func (t *CS) DataOrder() DataOrder { return t.o } func (t *CS) Slice(...Slice) (View, error) { return nil, errors.Errorf("Slice for sparse tensors not implemented yet") @@ -232,18 +233,17 @@ func (t *CS) T(axes ...int) error { } UnsafePermute(axes, []int(t.s)) t.o = t.o.toggleColMajor() - return errors.Errorf(methodNYI, "T") + t.o = MakeDataOrder(t.o, Transposed) + return nyierr(typeNYI, t) } // UT untransposes the CS -func (t *CS) UT() { t.T() } +func (t *CS) UT() { t.T(); t.o = t.o.clearTransposed() } // Transpose is a no-op. 
The data does not move func (t *CS) Transpose() error { return nil } -func (t *CS) Apply(fn interface{}, opts ...FuncOpt) (Tensor, error) { - return nil, errors.Errorf(methodNYI, "Apply") -} +func (t *CS) Apply(fn interface{}, opts ...FuncOpt) (Tensor, error) { return nil, nyierr(typeNYI, t) } func (t *CS) Eq(other interface{}) bool { if ot, ok := other.(*CS); ok { @@ -287,7 +287,7 @@ func (t *CS) Clone() interface{} { retVal.indptr = make([]int, len(t.indptr)) copy(retVal.indices, t.indices) copy(retVal.indptr, t.indptr) - retVal.array = makeArray(t.t, t.array.L) + retVal.array = makeArray(t.t, t.array.Len()) copyArray(&retVal.array, &t.array) retVal.e = t.e return retVal @@ -296,18 +296,17 @@ func (t *CS) Clone() interface{} { func (t *CS) IsScalar() bool { return false } func (t *CS) ScalarValue() interface{} { panic("Sparse Matrices cannot represent Scalar Values") } -func (t *CS) MemSize() uintptr { return uintptr(calcMemSize(t.t, t.array.L)) } -func (t *CS) Uintptr() uintptr { return uintptr(t.array.Ptr) } -func (t *CS) Pointer() unsafe.Pointer { return t.array.Ptr } +func (t *CS) MemSize() uintptr { return uintptr(calcMemSize(t.t, t.array.Len())) } +func (t *CS) Uintptr() uintptr { return t.array.Uintptr() } // NonZeroes returns the nonzeroes. In academic literature this is often written as NNZ. -func (t *CS) NonZeroes() int { return t.L } +func (t *CS) NonZeroes() int { return t.Len() } func (t *CS) RequiresIterator() bool { return true } func (t *CS) Iterator() Iterator { return NewFlatSparseIterator(t) } func (t *CS) at(coord ...int) (int, bool) { var r, c int - if t.o.isColMajor() { + if t.o.IsColMajor() { r = coord[1] c = coord[0] } else { @@ -329,8 +328,8 @@ func (t *CS) Dense() *Dense { // use } - d := recycledDense(t.t, t.Shape().Clone()) - if t.o.isColMajor() { + d := recycledDense(t.t, t.Shape().Clone(), WithEngine(t.e)) + if t.o.IsColMajor() { for i := 0; i < len(t.indptr)-1; i++ { for j := t.indptr[i]; j < t.indptr[i+1]; j++ { d.SetAt(t.Get(j), t.indices[j], i) @@ -361,14 +360,14 @@ func (t *CS) Indices() []int { } func (t *CS) AsCSR() { - if t.o.isRowMajor() { + if t.o.IsRowMajor() { return } t.o.toggleColMajor() } func (t *CS) AsCSC() { - if t.o.isColMajor() { + if t.o.IsColMajor() { return } t.o.toggleColMajor() @@ -379,4 +378,4 @@ func (t *CS) IsManuallyManaged() bool { return t.f.manuallyManaged() } func (t *CS) arr() array { return t.array } func (t *CS) arrPtr() *array { return &t.array } -func (t *CS) standardEngine() standardEngine { return nil } +func (t *CS) standardEngine() StandardEngine { return nil } diff --git a/sparse_test.go b/sparse_test.go index 86cdad1..34b22dd 100644 --- a/sparse_test.go +++ b/sparse_test.go @@ -1,105 +1,105 @@ -package tensor - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestCS_Basics(t *testing.T) { - assert := assert.New(t) - xs0 := []int{1, 2, 6, 8} - ys0 := []int{1, 2, 1, 6} - xs1 := []int{1, 2, 6, 8} - ys1 := []int{1, 2, 1, 6} - vals0 := []float64{3, 1, 4, 1} - vals1 := []float64{3, 1, 4, 1} - - var T0, T1 *CS - var d0, d1 *Dense - var dp0, dp1 *Dense - var err error - fails := func() { - CSCFromCoord(Shape{7, 6}, xs0, ys0, vals0) - } - assert.Panics(fails) - - // Test CSC - T0 = CSCFromCoord(Shape{9, 7}, xs0, ys0, vals0) - d0 = T0.Dense() - T0.T() - dp0 = T0.Dense() - T0.UT() // untranspose as Materialize() will be called below - - // Test CSR - fails = func() { - CSRFromCoord(Shape{7, 6}, xs1, ys1, vals1) - } - T1 = CSRFromCoord(Shape{9, 7}, xs1, ys1, vals1) - d1 = T1.Dense() - T1.T() - dp1 = 
T1.Dense() - T1.UT() - - t.Logf("%v %v", T0.indptr, T0.indices) - t.Logf("%v %v", T1.indptr, T1.indices) - - assert.True(d0.Eq(d1), "%+#v\n %+#v\n", d0, d1) - assert.True(dp0.Eq(dp1)) - assert.True(T1.Eq(T1)) - assert.False(T0.Eq(T1)) - - // At - var got interface{} - correct := float64(3.0) - if got, err = T0.At(1, 1); err != nil { - t.Error(err) - } - if got.(float64) != correct { - t.Errorf("Expected %v. Got %v - T0[1,1]", correct, got) - } - if got, err = T1.At(1, 1); err != nil { - t.Error(err) - } - if got.(float64) != correct { - t.Errorf("Expected %v. Got %v - T1[1,1]", correct, got) - } - - correct = 0.0 - if got, err = T0.At(3, 3); err != nil { - t.Error(err) - } - if got.(float64) != correct { - t.Errorf("Expected %v. Got %v - T0[3,3]", correct, got) - } - - if got, err = T1.At(3, 3); err != nil { - t.Error(err) - } - if got.(float64) != correct { - t.Errorf("Expected %v. Got %v - T1[3,3]", correct, got) - } - - // Test clone - T2 := T0.Clone() - assert.True(T0.Eq(T2)) - - // Scalar representation - assert.False(T0.IsScalar()) - fails = func() { - T0.ScalarValue() - } - assert.Panics(fails) - assert.Equal(len(vals0), T0.NonZeroes()) - - // Sparse Iterator - it := T0.Iterator() - var valids []int - correctValids := []int{0, 2, 1, 3} - for i, valid, err := it.NextValidity(); err == nil; i, valid, err = it.NextValidity() { - if valid { - valids = append(valids, i) - } - } - assert.Equal(correctValids, valids) -} +package tensor + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCS_Basics(t *testing.T) { + assert := assert.New(t) + xs0 := []int{1, 2, 6, 8} + ys0 := []int{1, 2, 1, 6} + xs1 := []int{1, 2, 6, 8} + ys1 := []int{1, 2, 1, 6} + vals0 := []float64{3, 1, 4, 1} + vals1 := []float64{3, 1, 4, 1} + + var T0, T1 *CS + var d0, d1 *Dense + var dp0, dp1 *Dense + var err error + fails := func() { + CSCFromCoord(Shape{7, 6}, xs0, ys0, vals0) + } + assert.Panics(fails) + + // Test CSC + T0 = CSCFromCoord(Shape{9, 7}, xs0, ys0, vals0) + d0 = T0.Dense() + T0.T() + dp0 = T0.Dense() + T0.UT() // untranspose as Materialize() will be called below + + // Test CSR + fails = func() { + CSRFromCoord(Shape{7, 6}, xs1, ys1, vals1) + } + T1 = CSRFromCoord(Shape{9, 7}, xs1, ys1, vals1) + d1 = T1.Dense() + T1.T() + dp1 = T1.Dense() + T1.UT() + + t.Logf("%v %v", T0.indptr, T0.indices) + t.Logf("%v %v", T1.indptr, T1.indices) + + assert.True(d0.Eq(d1), "%+#v\n %+#v\n", d0, d1) + assert.True(dp0.Eq(dp1)) + assert.True(T1.Eq(T1)) + assert.False(T0.Eq(T1)) + + // At + var got interface{} + correct := float64(3.0) + if got, err = T0.At(1, 1); err != nil { + t.Error(err) + } + if got.(float64) != correct { + t.Errorf("Expected %v. Got %v - T0[1,1]", correct, got) + } + if got, err = T1.At(1, 1); err != nil { + t.Error(err) + } + if got.(float64) != correct { + t.Errorf("Expected %v. Got %v - T1[1,1]", correct, got) + } + + correct = 0.0 + if got, err = T0.At(3, 3); err != nil { + t.Error(err) + } + if got.(float64) != correct { + t.Errorf("Expected %v. Got %v - T0[3,3]", correct, got) + } + + if got, err = T1.At(3, 3); err != nil { + t.Error(err) + } + if got.(float64) != correct { + t.Errorf("Expected %v. 
Got %v - T1[3,3]", correct, got) + } + + // Test clone + T2 := T0.Clone() + assert.True(T0.Eq(T2)) + + // Scalar representation + assert.False(T0.IsScalar()) + fails = func() { + T0.ScalarValue() + } + assert.Panics(fails) + assert.Equal(len(vals0), T0.NonZeroes()) + + // Sparse Iterator + it := T0.Iterator() + var valids []int + correctValids := []int{0, 2, 1, 3} + for i, valid, err := it.NextValidity(); err == nil; i, valid, err = it.NextValidity() { + if valid { + valids = append(valids, i) + } + } + assert.Equal(correctValids, valids) +} diff --git a/tensor.go b/tensor.go index d1b348a..8445a39 100644 --- a/tensor.go +++ b/tensor.go @@ -4,17 +4,15 @@ package tensor // import "gorgonia.org/tensor" import ( "encoding/gob" - "fmt" - "io" - "unsafe" "github.com/pkg/errors" + "gorgonia.org/dtype" ) var ( _ Tensor = &Dense{} _ Tensor = &CS{} - _ View = &Dense{} + _ View = &DenseView{} ) func init() { @@ -22,20 +20,27 @@ func init() { gob.Register(&CS{}) } -// Tensor represents a variety of n-dimensional arrays. The most commonly used tensor is the Dense tensor. -// It can be used to represent a vector, matrix, 3D matrix and n-dimensional tensors. -type Tensor interface { +// Desc is a description of a tensor. It does not actually deal with data. +type Desc interface { // info about the ndarray Shape() Shape Strides() []int - Dtype() Dtype + Dtype() dtype.Dtype + Dims() int Size() int DataSize() int +} + +// Tensor represents a variety of n-dimensional arrays. The most commonly used tensor is the Dense tensor. +// It can be used to represent a vector, matrix, 3D matrix and n-dimensional tensors. +type Tensor interface { + Desc // Data access related RequiresIterator() bool Iterator() Iterator + DataOrder() DataOrder // ops Slicer @@ -54,39 +59,34 @@ type Tensor interface { Eq Cloner - // type overloading methods - IsScalar() bool - ScalarValue() interface{} - // engine/memory related stuff // all Tensors should be able to be expressed of as a slab of memory // Note: the size of each element can be acquired by T.Dtype().Size() + Memory // Tensors all implement Memory Engine() Engine // Engine can be nil - MemSize() uintptr // the size in memory - Uintptr() uintptr // the pointer to the first element, as a uintptr - Pointer() unsafe.Pointer // the pointer to the first elemment as a unsafe.Ponter IsNativelyAccessible() bool // Can Go access the memory IsManuallyManaged() bool // Must Go manage the memory // formatters - fmt.Formatter - fmt.Stringer + // fmt.Formatter + // fmt.Stringer // all Tensors are serializable to these formats - WriteNpy(io.Writer) error - ReadNpy(io.Reader) error - gob.GobEncoder - gob.GobDecoder + //WriteNpy(io.Writer) error + //ReadNpy(io.Reader) error + //gob.GobEncoder + //gob.GobDecoder - standardEngine() standardEngine headerer arrayer + + // TO BE DEPRECATED + ScalarRep } // New creates a new Dense Tensor. For sparse arrays use their relevant construction function func New(opts ...ConsOpt) *Dense { d := borrowDense() - d.AP = new(AP) for _, opt := range opts { opt(d) } @@ -98,12 +98,26 @@ func New(opts ...ConsOpt) *Dense { return d } +// MustGetDense gets a *Dense from a given Tensor. Panics otherwise. 
+func MustGetDense(T Tensor) *Dense { + d, err := assertDense(T) + if err != nil { + panic(err) + } + return d +} + func assertDense(t Tensor) (*Dense, error) { if t == nil { return nil, errors.New("nil is not a *Dense") } - if retVal, ok := t.(*Dense); ok { - return retVal, nil + switch tt := t.(type) { + case *Dense: + return tt, nil + case DenseView: + return tt.Dense, nil + case Densor: + return tt.Dense(), nil } return nil, errors.Errorf("%T is not *Dense", t) } @@ -124,7 +138,7 @@ func getFloatDenseTensor(t Tensor) (retVal DenseTensor, err error) { if t == nil { return } - if err = typeclassCheck(t.Dtype(), floatTypes); err != nil { + if err = dtype.TypeClassCheck(t.Dtype(), dtype.Floats); err != nil { err = errors.Wrapf(err, "getFloatDense only handles floats. Got %v instead", t.Dtype()) return } @@ -140,10 +154,32 @@ func getFloatDenseTensor(t Tensor) (retVal DenseTensor, err error) { return } +// getFloatDense extracts a *Dense from a Tensor and ensures that the .data is a Array that implements Float +func getFloatComplexDenseTensor(t Tensor) (retVal DenseTensor, err error) { + if t == nil { + return + } + if err = dtype.TypeClassCheck(t.Dtype(), dtype.FloatComplex); err != nil { + err = errors.Wrapf(err, "getFloatDense only handles floats and complex. Got %v instead", t.Dtype()) + return + } + + if retVal, err = getDenseTensor(t); err != nil { + err = errors.Wrapf(err, opFail, "getFloatDense") + return + } + if retVal == nil { + return + } + + return +} + +// sliceDense returns a *Dense. func sliceDense(t *Dense, slices ...Slice) (retVal *Dense, err error) { var sliced Tensor if sliced, err = t.Slice(slices...); err != nil { return nil, err } - return sliced.(*Dense), nil + return sliced.(DenseView).Dense, nil } diff --git a/test_test.go b/test_test.go index 5f76d8a..f5a7e0c 100644 --- a/test_test.go +++ b/test_test.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -9,8 +7,11 @@ import ( "unsafe" "github.com/chewxy/math32" + "gorgonia.org/dtype" ) +// Code generated by genlib2. DO NOT EDIT. 
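As a quick illustration of the helper added above (a sketch only, not part of the patch; it assumes the package's usual Of/WithShape construction options), MustGetDense recovers a *Dense from a value held behind the Tensor interface and panics when the value is neither a *Dense, a DenseView, nor a Densor:

// exampleMustGetDense is a hypothetical usage sketch; Of and WithShape are
// assumed to be the package's existing constructor options.
func exampleMustGetDense() {
	var T Tensor = New(Of(Float64), WithShape(2, 3))

	d := MustGetDense(T) // panics if T cannot be viewed as a *Dense
	_ = d.Shape()        // Shape{2, 3}
}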
+ func anyToFloat64s(x interface{}) (retVal []float64) { switch xt := x.(type) { case []int: @@ -120,7 +121,7 @@ func anyToFloat64s(x interface{}) (retVal []float64) { panic("Unreachable") } -func identityVal(x int, dt Dtype) interface{} { +func identityVal(x int, dt dtype.Dtype) interface{} { switch dt { case Int: return int(x) diff --git a/testutils_test.go b/testutils_test.go index e219ab1..77312fb 100644 --- a/testutils_test.go +++ b/testutils_test.go @@ -1,6 +1,8 @@ package tensor import ( + "bytes" + "context" "errors" "math" "math/cmplx" @@ -13,6 +15,8 @@ import ( "github.com/chewxy/math32" "gorgonia.org/tensor/internal/storage" + + "gorgonia.org/dtype" ) func randomBool() bool { @@ -240,6 +244,7 @@ func allClose(a, b interface{}, approxFn ...interface{}) bool { return reflect.DeepEqual(a, b) } } + func checkErr(t *testing.T, expected bool, err error, name string, id interface{}) (cont bool) { switch { case expected: @@ -328,7 +333,7 @@ func shuffleInts(a []int, r *rand.Rand) { type TensorGenerator struct { ShapeConstraint Shape - DtypeConstraint Dtype + DtypeConstraint dtype.Dtype } func (g TensorGenerator) Generate(r *rand.Rand, size int) reflect.Value { @@ -340,8 +345,8 @@ func (g TensorGenerator) Generate(r *rand.Rand, size int) reflect.Value { func (t *Dense) Generate(r *rand.Rand, size int) reflect.Value { // generate type - ri := r.Intn(len(specializedTypes.set)) - of := specializedTypes.set[ri] + ri := r.Intn(len(specializedTypes)) + of := specializedTypes[ri] datatyp := reflect.SliceOf(of.Type) gendat, _ := quick.Value(datatyp, r) // generate dims @@ -475,17 +480,10 @@ func (e dummyEngine) Memclr(mem Memory) {} func (e dummyEngine) Memcpy(dst, src Memory) error { if e { var a, b storage.Header - a.Ptr = src.Pointer() - a.L = int(src.MemSize()) - a.C = int(src.MemSize()) - - b.Ptr = dst.Pointer() - b.L = int(dst.MemSize()) - b.C = int(dst.MemSize()) + a.Raw = storage.FromMemory(src.Uintptr(), src.MemSize()) + b.Raw = storage.FromMemory(dst.Uintptr(), dst.MemSize()) - abs := *(*[]byte)(unsafe.Pointer(&a)) - bbs := *(*[]byte)(unsafe.Pointer(&b)) - copy(bbs, abs) + copy(b.Raw, a.Raw) return nil } return errors.New("Unable to copy ") @@ -507,14 +505,26 @@ func (e dummyEngine2) Memcpy(dst, src Memory) error { return e.e.Mem func (e dummyEngine2) Accessible(mem Memory) (Memory, error) { return e.e.Accessible(mem) } func (e dummyEngine2) WorksWith(order DataOrder) bool { return e.e.WorksWith(order) } -func (e dummyEngine2) Argmax(t Tensor, axis int) (Tensor, error) { return e.e.Argmax(t, axis) } -func (e dummyEngine2) Argmin(t Tensor, axis int) (Tensor, error) { return e.e.Argmin(t, axis) } +func (e dummyEngine2) Argmax(ctx context.Context, t Tensor, axis int) (Tensor, error) { + return e.e.Argmax(ctx, t, axis) +} +func (e dummyEngine2) Argmin(ctx context.Context, t Tensor, axis int) (Tensor, error) { + return e.e.Argmin(ctx, t, axis) +} -func willerr(a *Dense, tc, eqtc *typeclass) (retVal, willFailEq bool) { - if err := typeclassCheck(a.Dtype(), eqtc); err == nil { +func willerr(a *Dense, tc, eqtc dtype.TypeClass) (retVal, willFailEq bool) { + if eqtc == nilTC { willFailEq = true + } else { + if err := dtype.TypeClassCheck(a.Dtype(), eqtc); err == nil { + willFailEq = true + } + } + if tc == nilTC { + retVal = !a.IsNativelyAccessible() + return } - if err := typeclassCheck(a.Dtype(), tc); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), tc); err != nil { return true, willFailEq } @@ -544,14 +554,14 @@ func qcErrCheck(t *testing.T, name string, a Dtyper, b interface{}, we 
bool, err return nil, false } -func qcIsFloat(dt Dtype) bool { - if err := typeclassCheck(dt, floatcmplxTypes); err == nil { +func qcIsFloat(dt dtype.Dtype) bool { + if err := dtype.TypeClassCheck(dt, dtype.FloatComplex); err == nil { return true } return false } -func qcEqCheck(t *testing.T, dt Dtype, willFailEq bool, correct, got interface{}) bool { +func qcEqCheck(t *testing.T, dt dtype.Dtype, willFailEq bool, correct, got interface{}) bool { isFloatTypes := qcIsFloat(dt) if !willFailEq && (isFloatTypes && !allClose(correct, got) || (!isFloatTypes && !reflect.DeepEqual(correct, got))) { t.Errorf("q.Dtype: %v", dt) @@ -561,3 +571,12 @@ func qcEqCheck(t *testing.T, dt Dtype, willFailEq bool, correct, got interface{} } return true } + +// DummyState is a dummy fmt.State, used to debug things +type DummyState struct { + *bytes.Buffer +} + +func (d *DummyState) Width() (int, bool) { return 0, false } +func (d *DummyState) Precision() (int, bool) { return 0, false } +func (d *DummyState) Flag(c int) bool { return false } diff --git a/type_test.go b/type_test.go index d616b8f..7200f66 100644 --- a/type_test.go +++ b/type_test.go @@ -1,66 +1,13 @@ package tensor import ( - "reflect" - "testing" + "gorgonia.org/dtype" ) -type Float16 uint16 - -func TestRegisterType(t *testing.T) { - dt := Dtype{reflect.TypeOf(Float16(0))} - RegisterFloat(dt) - - if err := typeclassCheck(dt, floatTypes); err != nil { - t.Errorf("Expected %v to be in floatTypes: %v", dt, err) - } - if err := typeclassCheck(dt, numberTypes); err != nil { - t.Errorf("Expected %v to be in numberTypes: %v", dt, err) - } - if err := typeclassCheck(dt, ordTypes); err != nil { - t.Errorf("Expected %v to be in ordTypes: %v", dt, err) - } - if err := typeclassCheck(dt, eqTypes); err != nil { - t.Errorf("Expected %v to be in eqTypes: %v", dt, err) - } - +var numberTypes = []dtype.Dtype{ + Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, } -func TestDtypeConversions(t *testing.T) { - for k, v := range reverseNumpyDtypes { - if npdt, err := v.numpyDtype(); npdt != k { - t.Errorf("Expected %v to return numpy dtype of %q. Got %q instead", v, k, npdt) - } else if err != nil { - t.Errorf("Error: %v", err) - } - } - dt := Dtype{reflect.TypeOf(Float16(0))} - if _, err := dt.numpyDtype(); err == nil { - t.Errorf("Expected an error when passing in type unknown to np") - } - - for k, v := range numpyDtypes { - if dt, err := fromNumpyDtype(v); dt != k { - // special cases - if Int.Size() == 4 && v == "i4" && dt == Int { - continue - } - if Int.Size() == 8 && v == "i8" && dt == Int { - continue - } - - if Uint.Size() == 4 && v == "u4" && dt == Uint { - continue - } - if Uint.Size() == 8 && v == "u8" && dt == Uint { - continue - } - t.Errorf("Expected %q to return %v. Got %v instead", v, k, dt) - } else if err != nil { - t.Errorf("Error: %v", err) - } - } - if _, err := fromNumpyDtype("EDIUH"); err == nil { - t.Error("Expected error when nonsense is passed into fromNumpyDtype") - } +var specializedTypes = []dtype.Dtype{ + Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String, } diff --git a/types.go b/types.go index fd8e189..0cc7ef1 100644 --- a/types.go +++ b/types.go @@ -4,100 +4,15 @@ import ( "fmt" "math" "reflect" - "unsafe" - "github.com/chewxy/hm" - "github.com/pkg/errors" + "gorgonia.org/dtype" ) -// Dtype represents a data type of a Tensor. 
Concretely it's implemented as an embedded reflect.Type -// which allows for easy reflection operations. It also implements hm.Type, for type inference in Gorgonia -type Dtype struct { - reflect.Type -} - -// note: the Name() and String() methods are already defined in reflect.Type. Might as well use the composed methods - -func (dt Dtype) Apply(hm.Subs) hm.Substitutable { return dt } -func (dt Dtype) FreeTypeVar() hm.TypeVarSet { return nil } -func (dt Dtype) Normalize(k, v hm.TypeVarSet) (hm.Type, error) { return dt, nil } -func (dt Dtype) Types() hm.Types { return nil } -func (dt Dtype) Format(s fmt.State, c rune) { fmt.Fprintf(s, "%s", dt.Name()) } -func (dt Dtype) Eq(other hm.Type) bool { return other == dt } - -var numpyDtypes map[Dtype]string -var reverseNumpyDtypes map[string]Dtype - -func init() { - numpyDtypes = map[Dtype]string{ - Bool: "b1", - Int: fmt.Sprintf("i%d", Int.Size()), - Int8: "i1", - Int16: "i2", - Int32: "i4", - Int64: "i8", - Uint: fmt.Sprintf("u%d", Uint.Size()), - Uint8: "u1", - Uint16: "u2", - Uint32: "u4", - Uint64: "u8", - Float32: "f4", - Float64: "f8", - Complex64: "c8", - Complex128: "c16", - } - - reverseNumpyDtypes = map[string]Dtype{ - "b1": Bool, - "i1": Int8, - "i2": Int16, - "i4": Int32, - "i8": Int64, - "u1": Uint8, - "u2": Uint16, - "u4": Uint32, - "u8": Uint64, - "f4": Float32, - "f8": Float64, - "c8": Complex64, - "c16": Complex128, - } -} - -// NumpyDtype returns the Numpy's Dtype equivalent. This is predominantly used in converting a Tensor to a Numpy ndarray, -// however, not all Dtypes are supported -func (dt Dtype) numpyDtype() (string, error) { - retVal, ok := numpyDtypes[dt] - if !ok { - return "v", errors.Errorf("Unsupported Dtype conversion to Numpy Dtype: %v", dt) - } - return retVal, nil -} +// Dtype is an alias for dtype.Dtype. This alias is here for backward compatibility purposes, for when users are transitioning out of the older tensor libraries. 
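For illustration (not part of the patch), the alias means the re-exported tensor dtypes and the values from gorgonia.org/dtype are now the same type, so existing call sites keep compiling while new code can mix the two freely:

// describeDtype is a hypothetical sketch; because Dtype is a type alias rather
// than a distinct defined type, dtype.Float64 and the package-level Float64
// re-exported below are interchangeable.
func describeDtype(dt Dtype) string { return dt.String() }

var _ = describeDtype(Float64)       // tensor's re-exported value
var _ = describeDtype(dtype.Float64) // the aliased package's value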
+type Dtype = dtype.Dtype -func fromNumpyDtype(t string) (Dtype, error) { - retVal, ok := reverseNumpyDtypes[t] - if !ok { - return Dtype{}, errors.Errorf("Unsupported Dtype conversion from %q to Dtype", t) - } - if t == "i4" && Int.Size() == 4 { - return Int, nil - } - if t == "i8" && Int.Size() == 8 { - return Int, nil - } - if t == "u4" && Uint.Size() == 4 { - return Uint, nil - } - if t == "u8" && Uint.Size() == 8 { - return Uint, nil - } - return retVal, nil -} - -type typeclass struct { - name string - set []Dtype -} +// nil type class for skipping type class checks +var nilTC dtype.TypeClass = -1 var parameterizedKinds = [...]reflect.Kind{ reflect.Array, @@ -119,204 +34,31 @@ func isParameterizedKind(k reflect.Kind) bool { return false } -// oh how nice it'd be if I could make them immutable -var ( - Bool = Dtype{reflect.TypeOf(true)} - Int = Dtype{reflect.TypeOf(int(1))} - Int8 = Dtype{reflect.TypeOf(int8(1))} - Int16 = Dtype{reflect.TypeOf(int16(1))} - Int32 = Dtype{reflect.TypeOf(int32(1))} - Int64 = Dtype{reflect.TypeOf(int64(1))} - Uint = Dtype{reflect.TypeOf(uint(1))} - Uint8 = Dtype{reflect.TypeOf(uint8(1))} - Uint16 = Dtype{reflect.TypeOf(uint16(1))} - Uint32 = Dtype{reflect.TypeOf(uint32(1))} - Uint64 = Dtype{reflect.TypeOf(uint64(1))} - Float32 = Dtype{reflect.TypeOf(float32(1))} - Float64 = Dtype{reflect.TypeOf(float64(1))} - Complex64 = Dtype{reflect.TypeOf(complex64(1))} - Complex128 = Dtype{reflect.TypeOf(complex128(1))} - String = Dtype{reflect.TypeOf("")} - - // aliases - Byte = Uint8 +func isFloat(dt dtype.Dtype) bool { return dt == Float64 || dt == Float32 } - // extras - Uintptr = Dtype{reflect.TypeOf(uintptr(0))} - UnsafePointer = Dtype{reflect.TypeOf(unsafe.Pointer(&Uintptr))} +// type aliases +var ( + Bool = dtype.Bool + Int = dtype.Int + Int8 = dtype.Int8 + Int16 = dtype.Int16 + Int32 = dtype.Int32 + Int64 = dtype.Int64 + Uint = dtype.Uint + Uint8 = dtype.Uint8 + Uint16 = dtype.Uint16 + Uint32 = dtype.Uint32 + Uint64 = dtype.Uint64 + Float32 = dtype.Float32 + Float64 = dtype.Float64 + Complex64 = dtype.Complex64 + Complex128 = dtype.Complex128 + String = dtype.String + Byte = dtype.Byte + Uintptr = dtype.Uintptr + UnsafePointer = dtype.UnsafePointer ) -// allTypes for indexing -var allTypes = &typeclass{ - name: "τ", - set: []Dtype{ - Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String, Uintptr, UnsafePointer, - }, -} - -// specialized types indicate that there are specialized code generated for these types -var specializedTypes = &typeclass{ - name: "Specialized", - set: []Dtype{ - Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String, - }, -} - -var addableTypes = &typeclass{ - name: "Addable", - set: []Dtype{ - Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String, - }, -} - -var numberTypes = &typeclass{ - name: "Number", - set: []Dtype{ - Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, - }, -} - -var ordTypes = &typeclass{ - name: "Ord", - set: []Dtype{ - Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, String, - }, -} - -var eqTypes = &typeclass{ - name: "Eq", - set: []Dtype{ - Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String, Uintptr, UnsafePointer, - }, -} - -var unsignedTypes = 
&typeclass{ - name: "Unsigned", - set: []Dtype{Uint, Uint8, Uint16, Uint32, Uint64}, -} - -var signedTypes = &typeclass{ - name: "Signed", - set: []Dtype{ - Int, Int8, Int16, Int32, Int64, Float32, Float64, Complex64, Complex128, - }, -} - -// this typeclass is ever only used by Sub tests -var signedNonComplexTypes = &typeclass{ - name: "Signed NonComplex", - set: []Dtype{ - Int, Int8, Int16, Int32, Int64, Float32, Float64, - }, -} - -var floatTypes = &typeclass{ - name: "Float", - set: []Dtype{ - Float32, Float64, - }, -} - -var complexTypes = &typeclass{ - name: "Complex Numbers", - set: []Dtype{Complex64, Complex128}, -} - -var floatcmplxTypes = &typeclass{ - name: "Real", - set: []Dtype{ - Float32, Float64, Complex64, Complex128, - }, -} - -var nonComplexNumberTypes = &typeclass{ - name: "Non complex numbers", - set: []Dtype{ - Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, - }, -} - -// this typeclass is ever only used by Pow tests -var generatableTypes = &typeclass{ - name: "Generatable types", - set: []Dtype{ - Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, String, - }, -} - -func isFloat(dt Dtype) bool { - return dt == Float64 || dt == Float32 -} - -func typeclassCheck(a Dtype, tc *typeclass) error { - if tc == nil { - return nil - } - for _, s := range tc.set { - if s == a { - return nil - } - } - return errors.Errorf("Type %v is not a member of %v", a, tc.name) -} - -// RegisterNumber is a function required to register a new numerical Dtype. -// This package provides the following Dtype: -// Int -// Int8 -// Int16 -// Int32 -// Int64 -// Uint -// Uint8 -// Uint16 -// Uint32 -// Uint64 -// Float32 -// Float64 -// Complex64 -// Complex128 -// -// If a Dtype that is registered already exists on the list, it will not be added to the list. -func RegisterNumber(a Dtype) { - for _, dt := range numberTypes.set { - if dt == a { - return - } - } - numberTypes.set = append(numberTypes.set, a) - RegisterEq(a) -} - -func RegisterFloat(a Dtype) { - for _, dt := range floatTypes.set { - if dt == a { - return - } - } - floatTypes.set = append(floatTypes.set, a) - RegisterNumber(a) - RegisterOrd(a) -} - -func RegisterOrd(a Dtype) { - for _, dt := range ordTypes.set { - if dt == a { - return - } - } - ordTypes.set = append(ordTypes.set, a) -} - -func RegisterEq(a Dtype) { - for _, dt := range eqTypes.set { - if dt == a { - return - } - } - eqTypes.set = append(eqTypes.set, a) -} - // NormOrder represents the order of the norm. Ideally, we'd only represent norms with a uint/byte. // But there are norm types that are outside numerical types, such as nuclear norm and fobenius norm. // So it is internally represented by a float. If Go could use NaN and Inf as consts, it would have been best, @@ -387,54 +129,3 @@ func (n NormOrder) String() string { } panic("unreachable") } - -// FuncOpt are optionals for calling Tensor function. -type FuncOpt func(*OpOpt) - -// WithIncr passes in a Tensor to be incremented. -func WithIncr(incr Tensor) FuncOpt { - f := func(opt *OpOpt) { - opt.incr = incr - } - return f -} - -// WithReuse passes in a Tensor to be reused. -func WithReuse(reuse Tensor) FuncOpt { - f := func(opt *OpOpt) { - opt.reuse = reuse - } - return f -} - -// UseSafe ensures that the operation is a safe operation (copies data, does not clobber). 
This is the default option for most methods and functions -func UseSafe() FuncOpt { - f := func(opt *OpOpt) { - opt.unsafe = false - } - return f -} - -// UseUnsafe ensures that the operation is an unsafe operation - data will be clobbered, and operations performed inplace -func UseUnsafe() FuncOpt { - f := func(opt *OpOpt) { - opt.unsafe = true - } - return f -} - -// AsSameType makes sure that the return Tensor is the same type as input Tensors. -func AsSameType() FuncOpt { - f := func(opt *OpOpt) { - opt.same = true - } - return f -} - -// As makes sure that the the return Tensor is of the type specified. Currently only works for FromMat64 -func As(t Dtype) FuncOpt { - f := func(opt *OpOpt) { - opt.t = t - } - return f -} diff --git a/unsafe.go b/unsafe.go new file mode 100644 index 0000000..5260605 --- /dev/null +++ b/unsafe.go @@ -0,0 +1,3 @@ +package tensor + +import _ "go4.org/unsafe/assume-no-moving-gc" diff --git a/utils.go b/utils.go index 9dcd936..426a1dd 100644 --- a/utils.go +++ b/utils.go @@ -1,6 +1,10 @@ package tensor -import "github.com/pkg/errors" +import ( + "context" + + "github.com/pkg/errors" +) const AllAxes int = -1 @@ -40,35 +44,16 @@ func SumInts(a []int) (retVal int) { // ProdInts returns the internal product of an int slice func ProdInts(a []int) (retVal int) { + retVal = 1 if len(a) == 0 { return } - retVal = 1 for _, v := range a { retVal *= v } return } -// EqInts returns true if slices have same value -// func EqInts(a, b []int) bool { -// if len(a) != len(b) { -// return false -// } - -// if (a == nil) != (b == nil) { -// return false -// } - -// b = b[:len(a)] -// for i, v := range a { -// if v != b[i] { -// return false -// } -// } -// return true -// } - // IsMonotonicInts returns true if the slice of ints is monotonically increasing. It also returns true for incr1 if every succession is a succession of 1 func IsMonotonicInts(a []int) (monotonic bool, incr1 bool) { var prev int @@ -93,6 +78,14 @@ func IsMonotonicInts(a []int) (monotonic bool, incr1 bool) { // Ltoi is Location to Index. Provide a shape, a strides, and a list of integers as coordinates, and returns the index at which the element is. 
func Ltoi(shape Shape, strides []int, coords ...int) (at int, err error) { + if shape.IsScalarEquiv() { + for _, v := range coords { + if v != 0 { + return -1, errors.Errorf("Scalar shape only allows 0 as an index") + } + } + return 0, nil + } for i, coord := range coords { if i >= len(shape) { err = errors.Errorf(dimMismatch, len(shape), i) @@ -107,23 +100,16 @@ func Ltoi(shape Shape, strides []int, coords ...int) (at int, err error) { } var stride int - if shape.IsRowVec() { - if i == 0 && len(coords) == 2 { - continue - } - stride = strides[0] - } else if shape.IsColVec() { - if i == 1 && len(coords) == 2 { - continue - } + switch { + case shape.IsVector() && len(strides) == 1: stride = strides[0] - } else { - if i >= len(strides) { - err = errors.Errorf(dimMismatch, len(strides), i) - return - } + case i >= len(strides): + err = errors.Errorf(dimMismatch, len(strides), i) + return + default: stride = strides[i] } + at += stride * coord } return at, nil @@ -213,7 +199,6 @@ func UnsafePermute(pattern []int, xs ...[]int) (err error) { return nil } - // CheckSlice checks a slice to see if it's sane func CheckSlice(s Slice, size int) error { start := s.Start() @@ -261,38 +246,35 @@ func SliceDetails(s Slice, size int) (start, end, step int, err error) { return } -// reuseDenseCheck checks a reuse tensor, and reshapes it to be the correct one -func reuseDenseCheck(reuse DenseTensor, as DenseTensor) (err error) { - if reuse.DataSize() != as.Size() { - err = errors.Errorf("Reused Tensor %p does not have expected shape %v. Got %v instead. Reuse Size: %v, as Size %v (real: %d)", reuse, as.Shape(), reuse.Shape(), reuse.DataSize(), as.Size(), as.DataSize()) - return - } - return reuseCheckShape(reuse, as.Shape()) - -} - -// reuseCheckShape checks the shape and reshapes it to be correct if the size fits but the shape doesn't. -func reuseCheckShape(reuse DenseTensor, s Shape) (err error) { +// checkFixShape checks the shape and reshapes it to be correct if the size fits but the shape doesn't. +func checkFixShape(reuse Tensor, s Shape) (err error) { throw := BorrowInts(len(s)) copy(throw, s) - if err = reuse.reshape(throw...); err != nil { - err = errors.Wrapf(err, reuseReshapeErr, s, reuse.DataSize()) + d, ok := reuse.(DenseTensor) + if !ok { + if err = reuse.Reshape(throw...); err != nil { + return errors.Wrapf(err, reuseReshapeErr, s, reuse.DataSize()) + } + return nil + } + + if err = d.reshape(throw...); err != nil { + err = errors.Wrapf(err, reuseReshapeErr, s, d.DataSize()) return } // clean up any funny things that may be in the reuse - if oldAP := reuse.oldAP(); oldAP != nil { - ReturnAP(oldAP) - reuse.setOldAP(nil) + if oldAP := d.oldAP(); !oldAP.IsZero() { + oldAP.zero() } - if axes := reuse.transposeAxes(); axes != nil { + if axes := d.transposeAxes(); axes != nil { ReturnInts(axes) } - if viewOf := reuse.parentTensor(); viewOf != nil { - reuse.setParentTensor(nil) + if viewOf := d.parentTensor(); viewOf != nil { + d.setParentTensor(nil) } return nil } @@ -309,6 +291,45 @@ func memsetBools(a []bool, v bool) { } } +// allones checks that a slice of ints are all 1. +func allones(a []int) bool { + for i := range a { + if a[i] != 1 { + return false + } + } + return true +} + +// ctxFromEngine gets a context from an engine if it's a contexter. 
Otherwise it returns a context.Background() +func ctxFromEngine(e Engine) context.Context { + if c, ok := e.(contexter); ok { + return c.Context() + } + return context.Background() +} + +func getFloat64s(a Tensor) []float64 { + if um, ok := a.(unsafeMem); ok { + return um.Float64s() + } + return a.Data().([]float64) +} + +func getFloat32s(a Tensor) []float32 { + if um, ok := a.(unsafeMem); ok { + return um.Float32s() + } + return a.Data().([]float32) +} + +func getInts(a Tensor) []int { + if um, ok := a.(unsafeMem); ok { + return um.Ints() + } + return a.Data().([]int) + +} /* FOR ILLUSTRATIVE PURPOSES */ @@ -385,4 +406,4 @@ func Permute(pattern []int, xs ...[]int) (retVal [][]int, err error) { } return } -*/ \ No newline at end of file +*/
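To make the utils.go changes above concrete, a small sketch (illustrative values only) of the arithmetic Ltoi performs and of the new ProdInts behaviour for empty slices:

// exampleLtoi is a hypothetical sketch. Ltoi accumulates stride*coord per
// axis, so for a row-major 2x3 matrix with strides {3, 1} the element at
// coordinates (1, 2) sits at flat index 1*3 + 2*1 = 5. With this patch,
// scalar-equivalent shapes short-circuit and accept only all-zero coordinates.
func exampleLtoi() {
	at, err := Ltoi(Shape{2, 3}, []int{3, 1}, 1, 2)
	if err != nil {
		panic(err)
	}
	_ = at // 5

	// ProdInts now initialises its accumulator to 1 before the empty check,
	// so an empty shape reports a total size of 1 rather than 0.
	_ = ProdInts(nil) // 1
}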