-
Notifications
You must be signed in to change notification settings - Fork 90
/
Copy pathmain.cpp
81 lines (66 loc) · 1.44 KB
/
main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;
/// CRTP mix-in exposing a public `myfunc` that is statically dispatched to
/// `Derived::myfunc_impl`.
///
/// If `Derived` declares no `myfunc_impl` of its own, name lookup falls back to
/// the default implementation below (every element multiplied by 2.0). When the
/// derived `myfunc_impl` is private, `Derived` must befriend
/// `BaseImpl<Derived>` so the dispatch can reach it.
///
/// NOTE(review): `*this` is assumed to be the `BaseImpl` subobject of an actual
/// `Derived` object; calling `myfunc` on a standalone `BaseImpl<Derived>` makes
/// the downcast in `as_derived()` undefined behaviour.
template <class Derived>
class BaseImpl
{
public:
    /// @brief Return a transformed copy of @p t; the transformation is chosen
    ///        by `Derived::myfunc_impl` (or the default below).
    /// @tparam T iterable container whose elements support `*= double`.
    template <class T>
    T myfunc(const T& t)
    {
        Derived& self = as_derived();
        return self.myfunc_impl(t);
    }

private:
    /// Downcast to the concrete CRTP type.
    Derived& as_derived()
    {
        return static_cast<Derived&>(*this);
    }

    /// Default implementation: a copy of @p t with each element doubled.
    template <class T>
    T myfunc_impl(const T& t)
    {
        T doubled(t);
        for (auto& value : doubled) {
            value *= 2.0;
        }
        return doubled;
    }
};
/// Concrete CRTP class that keeps the default `BaseImpl::myfunc_impl`, so
/// `myfunc` returns a copy of its argument with every element doubled.
class Base : public BaseImpl<Base>
{
    // Grants BaseImpl<Base> access to Base's private members, mirroring the
    // pattern used by `Derived` (Base itself currently has none).
    friend class BaseImpl<Base>;

public:
    Base() = default;
};
/// Concrete CRTP class that supplies its own `myfunc_impl`, so `myfunc`
/// returns a copy of its argument with every element tripled.
class Derived : public BaseImpl<Derived>
{
public:
    Derived() = default;

private:
    // BaseImpl<Derived>::myfunc must be able to call the private hook below.
    friend class BaseImpl<Derived>;

    /// Static "override" of BaseImpl's default (which doubles elements):
    /// returns a copy of @p t with each element multiplied by 3.0.
    template <class T>
    T myfunc_impl(const T& t)
    {
        T tripled(t);
        for (auto& value : tripled) {
            value *= 3.0;
        }
        return tripled;
    }
};
/// Register the methods shared by every `BaseImpl<X>` specialisation on the
/// pybind11 wrapper @p self.
///
/// @tparam T the CRTP base being bound (e.g. `BaseImpl<Base>`).
/// @tparam M the `py::class_<T, ...>` wrapper type (deduced from @p self).
///
/// No constructor is bound here on purpose. A standalone `T` created from Python
/// would not be the base subobject of a concrete `Base`/`Derived`, so calling
/// `myfunc` on it would `static_cast` `this` to a derived object that does not
/// exist (undefined behaviour). Python instances must be created through the
/// concrete classes, which bind their own `py::init<>()`; they still inherit
/// `myfunc` from the base registered here.
template <class T, class M>
void registerBaseImpl(M& self)
{
    self.def("myfunc", &T::template myfunc<std::vector<double>>);
}
PYBIND11_MODULE(mymodule, m)
{
    m.doc() = "CRTP example";

    // Short names for the two CRTP base specialisations.
    using BaseOfBase = BaseImpl<Base>;
    using BaseOfDerived = BaseImpl<Derived>;

    // Wrap the CRTP bases first: the concrete classes below name them as their
    // pybind11 bases, so they must already be registered.
    py::class_<BaseOfBase> base_of_base(m, "BaseBase");
    py::class_<BaseOfDerived> base_of_derived(m, "BaseDerived");
    registerBaseImpl<BaseOfBase>(base_of_base);
    registerBaseImpl<BaseOfDerived>(base_of_derived);

    // Concrete, constructible classes; `myfunc` is inherited from their bases.
    py::class_<Base, BaseOfBase> concrete_base(m, "Base");
    concrete_base.def(py::init<>());
    py::class_<Derived, BaseOfDerived> concrete_derived(m, "Derived");
    concrete_derived.def(py::init<>());
}