Skip to content

Commit e62411d

Browse files
committed
adding node constructor
1 parent 3925a6d commit e62411d

File tree

10 files changed

+124
-5
lines changed

10 files changed

+124
-5
lines changed

include/node/node_factory.hpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#ifndef NODE_FACTORY_HPP
2+
#define NODE_FACTORY_HPP
3+
4+
# include "node/node.hpp"
5+
# include "array/array.hpp"
6+
7+
template <typename T> [[gnu::used]]
8+
std::shared_ptr<Node> newNode(
9+
const std::string& name = "",
10+
const std::string& type = "",
11+
const T& data = false,
12+
std::shared_ptr<Node> parent = nullptr);
13+
14+
#endif

src/c++/node/node.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,6 @@ std::shared_ptr<Data> Node::dataPtr() const {
128128
return this->_data;
129129
}
130130

131-
132131
void Node::setData(std::shared_ptr<Data> d) {
133132
this->_data = std::move(d);
134133
}

src/c++/node/node_factory.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# include "node/node_factory.hpp"
2+
3+
4+
template <typename T> [[gnu::used]]
5+
std::shared_ptr<Node> newNode(
6+
const std::string& name,
7+
const std::string& type,
8+
const T& data,
9+
std::shared_ptr<Node> parent) {
10+
11+
std::shared_ptr<Node> node = std::make_shared<Node>();
12+
node->setName(name);
13+
node->setType(type);
14+
Array dataArray(data);
15+
node->setData(dataArray);
16+
if (parent) node->attachTo(parent);
17+
18+
return node;
19+
}
20+
21+
/*
22+
template instantiations
23+
*/
24+
25+
template <typename... T>
26+
struct Instantiator {
27+
template <typename... U>
28+
void operator()() const {
29+
(static_cast<void>(newNode<U>(std::string{},std::string{},T{},std::shared_ptr<Node>{})), ...);
30+
}
31+
};
32+
33+
template void utils::instantiateFromTypeList<Instantiator, utils::StringAndScalarTypes>();

tests/c++/array/test_array.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,12 +140,12 @@ void test_isScalar() {
140140
throw py::value_error("should have been detected as scalar");
141141
}
142142

143-
Array directScalarArray = Array(12);
143+
T scalar(1);
144+
Array directScalarArray = Array(scalar);
144145
if (!directScalarArray.isScalar()) {
145146
throw py::value_error("should have been detected as scalar");
146147
}
147148

148-
149149
Array vectorArray = arrayfactory::zeros<T>({2});
150150
if (vectorArray.isScalar()) {
151151
throw py::value_error("should not have been detected as scalar");
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# include "test_node_factory.hpp"
2+
3+
using namespace std::string_literals;
4+
5+
void test_newNodeNoArgs() {
6+
// call without arguments requires type instantiation since
7+
// it cannot guess data type
8+
auto node = newNode<int>();
9+
}
10+
11+
void test_newNodeOnlyName() {
12+
auto node = newNode<int>("TheName"s);
13+
}
14+
15+
void test_newNodeNameAndType() {
16+
auto node = newNode<int>("TheName"s, "TheType"s);
17+
}
18+
19+
void test_newNodeNameTypeAndData() {
20+
// should be able to automatically find type
21+
auto node = newNode("TheName"s, "TheType"s, 1);
22+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#ifndef TEST_NODE_FACTORY_H
2+
#define TEST_NODE_FACTORY_H
3+
4+
# include <node/node.hpp>
5+
# include <node/node_factory.hpp>
6+
7+
void test_newNodeNoArgs();
8+
void test_newNodeOnlyName();
9+
void test_newNodeNameAndType();
10+
void test_newNodeNameTypeAndData();
11+
12+
#endif
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# ifndef TEST_NODE_FACTORY_PYBIND_HPP
2+
# define TEST_NODE_FACTORY_PYBIND_HPP
3+
4+
# include <pybind11/pybind11.h>
5+
6+
# include "test_node_factory.hpp"
7+
8+
void bindTestsOfNodeFactory(py::module_ &m) {
9+
10+
py::module_ sm = m.def_submodule("node_factory");
11+
12+
sm.def("test_newNodeNoArgs", &test_newNodeNoArgs);
13+
sm.def("test_newNodeOnlyName", &test_newNodeOnlyName);
14+
sm.def("test_newNodeNameAndType", &test_newNodeNameAndType);
15+
sm.def("test_newNodeNameTypeAndData", &test_newNodeNameTypeAndData);
16+
17+
18+
19+
}
20+
21+
# endif

tests/c++/test_core_pybind.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# include "array/test_array_pybind.hpp"
88
# include "node/test_node_pybind.hpp"
9+
# include "node/test_node_factory_pybind.hpp"
910
# include "data/data_factory.hpp"
1011
# include "node/test_data_pybind.hpp"
1112
# include "node/test_navigation_pybind.hpp"
@@ -27,6 +28,7 @@ PYBIND11_MODULE(tests, m) {
2728

2829
bindTestsOfArray(m);
2930
bindTestsOfNode(m);
31+
bindTestsOfNodeFactory(m);
3032
bindTestsOfData(m);
3133
bindTestsOfNavigation(m);
3234

tests/python/array/test_array.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,14 @@ def test_isScalar(dtype):
9090
scalarArray = zeros_builder([1],'C')
9191
assert scalarArray.isScalar()
9292

93-
directScalarArray = Array(12)
94-
assert directScalarArray.isScalar()
93+
directBooleanArray = Array(True)
94+
assert directBooleanArray.isScalar()
95+
96+
directIntegerArray = Array(12)
97+
assert directIntegerArray.isScalar()
98+
99+
directFloatArray = Array(3.14159)
100+
assert directFloatArray.isScalar()
95101

96102
vectorArray = zeros_builder([2],'C')
97103
assert not vectorArray.isScalar()
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import pytest
2+
import noder.tests.node_factory as test_in_cpp
3+
4+
def test_cpp_newNodeNoArgs(): return test_in_cpp.test_newNodeNoArgs()
5+
def test_cpp_newNodeOnlyName(): return test_in_cpp.test_newNodeOnlyName()
6+
def test_cpp_newNodeNameAndType(): return test_in_cpp.test_newNodeNameAndType()
7+
def test_cpp_newNodeNameTypeAndData(): return test_in_cpp.test_newNodeNameTypeAndData()
8+
9+
if __name__ == '__main__':
10+
test_cpp_newNodeNoArgs()

0 commit comments

Comments
 (0)