00001 #ifndef PSDDL_PYTHON_DDLWRAPPER_H
00002 #define PSDDL_PYTHON_DDLWRAPPER_H 1
00003
00004 #include <boost/python.hpp>
00005
00006 #include <iostream>
00007 #include <algorithm>
00008 #include <boost/python/to_python_converter.hpp>
00009
00010 #include "ndarray/ndarray.h"
00011 #include "psddl_python/psddl_python_numpy.h"
00012
00013 namespace psddl_python {
00014 namespace detail {
00015
00016
00017 template <typename CTYPE>
00018 PyObject*
00019 ndToList(const ndarray<const CTYPE, 1>& a) {
00020 PyObject* res = PyList_New(a.shape()[0]);
00021 for (size_t i = 0; i != a.shape()[0]; ++ i) {
00022 boost::python::object elem(a[i]);
00023 Py_INCREF(elem.ptr());
00024 PyList_SET_ITEM(res, i, elem.ptr());
00025 }
00026 return res;
00027 }
00028
00029
00030 inline
00031 PyObject*
00032 vintToList(const std::vector<int>& v) {
00033 PyObject* res = PyList_New(v.size());
00034 for (size_t i = 0; i != v.size(); ++ i) {
00035 PyList_SET_ITEM(res, i, PyInt_FromLong(v[i]));
00036 }
00037 return res;
00038 }
00039
00040
00041 template <typename T> struct PyArrayTraits {};
00042 #define ASSOCIATE_PYARRAYTYPE(CTYPE, PTYPE) template <> struct PyArrayTraits<CTYPE> { enum { type_code = PTYPE }; };
00043 ASSOCIATE_PYARRAYTYPE(int8_t, PyArray_BYTE);
00044 ASSOCIATE_PYARRAYTYPE(uint8_t, PyArray_UBYTE);
00045 ASSOCIATE_PYARRAYTYPE(int16_t, PyArray_SHORT);
00046 ASSOCIATE_PYARRAYTYPE(uint16_t, PyArray_USHORT);
00047 ASSOCIATE_PYARRAYTYPE(int32_t, PyArray_INT);
00048 ASSOCIATE_PYARRAYTYPE(uint32_t, PyArray_UINT);
00049 ASSOCIATE_PYARRAYTYPE(float, PyArray_FLOAT);
00050 ASSOCIATE_PYARRAYTYPE(double, PyArray_DOUBLE);
00051 #undef ASSOCIATE_PYARRAYTYPE
00052
00053
00054 template <typename T, unsigned NDim>
00055 void _ndarray_dtor(void* ptr)
00056 {
00057 delete static_cast<ndarray<const T, NDim>*>(ptr);
00058 }
00059
00060
00061 template <typename T, unsigned NDim, typename U>
00062 PyObject*
00063 ndToNumpy(const ndarray<const T, NDim>& array, const boost::shared_ptr<U>& owner)
00064 {
00065
00066
00067 ndarray<const T, NDim>* copy = new ndarray<const T, NDim>(array);
00068 if (not owner) {
00069
00070 *copy = copy->copy();
00071 }
00072 PyObject* ndarr = PyCObject_FromVoidPtr(static_cast<void*>(copy), _ndarray_dtor<T, NDim>);
00073
00074
00075 const unsigned* shape = copy->shape();
00076 npy_intp dims[NDim];
00077 std::copy(shape, shape+NDim, dims);
00078 PyObject* nparr = PyArray_SimpleNewFromData(NDim, dims, PyArrayTraits<T>::type_code, (void*)copy->data());
00079
00080
00081 const int* strides = copy->strides();
00082 for (unsigned i = 0; i != NDim; ++ i) PyArray_STRIDES(nparr)[i] = strides[i]*sizeof(T);
00083
00084
00085 ((PyArrayObject*)nparr)->base = ndarr;
00086
00087
00088 PyArray_FLAGS(nparr) &= ~NPY_WRITEABLE;
00089
00090 return nparr;
00091 }
00092
00093
00094 template <typename T, unsigned NDim>
00095 PyObject*
00096 ndToNumpy(const ndarray<const T, NDim>& array)
00097 {
00098 return ndToNumpy(array, boost::shared_ptr<void>());
00099 }
00100
00101
00102 template <typename T, unsigned NDim>
00103 struct ndarray_to_numpy_cvt {
00104 static PyObject* convert(const ndarray<T, NDim>& x) { return ndToNumpy(x); }
00105 static PyTypeObject const* get_pytype() { return &PyArray_Type; }
00106 };
00107
00108 template <typename T, unsigned NDim>
00109 void
00110 register_ndarray_to_numpy_cvt()
00111 {
00112
00113 typedef ndarray<T, NDim> ndtype;
00114 typedef ndarray_to_numpy_cvt<T, NDim> cvttype;
00115 boost::python::type_info tinfo = boost::python::type_id<ndtype>();
00116 boost::python::converter::registration const* reg = boost::python::converter::registry::query(tinfo);
00117 if (not reg or not reg->m_to_python) {
00118 boost::python::to_python_converter<ndtype, cvttype, true>();
00119 }
00120 }
00121
00122
00123 template <typename T>
00124 struct ndarray_to_list_cvt {
00125 static PyObject* convert(const ndarray<T, 1>& x) { return ndToList(x); }
00126 static PyTypeObject const* get_pytype() { return &PyList_Type; }
00127 };
00128
00129 template <typename T>
00130 void
00131 register_ndarray_to_list_cvt()
00132 {
00133
00134 typedef ndarray<T, 1> ndtype;
00135 typedef ndarray_to_list_cvt<T> cvttype;
00136 boost::python::type_info tinfo = boost::python::type_id<ndtype>();
00137 boost::python::converter::registration const* reg = boost::python::converter::registry::query(tinfo);
00138 if (not reg or not reg->m_to_python) {
00139 boost::python::to_python_converter<ndtype, cvttype, true>();
00140 }
00141 }
00142
00143 }
00144 }
00145
00146 #endif // PSDDL_PYTHON_DDLWRAPPER_H