Skip to content

Commit

Permalink
DetectorMap: add vectorized findPoint,findWavelength
Browse files Browse the repository at this point in the history
Iterating in C++ is much faster than in python, even when there's a virtual
function call involved. However, pybind11 doesn't like us having a virtual
function that takes scalars and a non-virtual function that takes vectors
and iteratively calls the scalar version (because the virtual function has
to be re-defined in the subclass, which then doesn't look to the base
class for overloads). Instead, use a different name for the virtual
function, and define all the functions in the base class. This way, there's
no overloads to confuse pybind: everything's handled in C++.
  • Loading branch information
PaulPrice committed Nov 13, 2020
1 parent 435ab49 commit 9ee8c6a
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 56 deletions.
38 changes: 36 additions & 2 deletions include/pfs/drp/stella/DetectorMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,35 @@ class DetectorMap : public lsst::afw::table::io::Persistable {
/// Return the fiberId given a position on the detector
virtual int findFiberId(lsst::geom::PointD const& point) const = 0;

//@{
/// Return the position of the fiber trace on the detector, given a fiberId and wavelength
virtual lsst::geom::PointD findPoint(int fiberId, float wavelength) const = 0;
lsst::geom::PointD findPoint(int fiberId, float wavelength) const {
return findPointImpl(fiberId, wavelength);
}
ndarray::Array<float, 2, 1> findPoint(
int fiberId,
ndarray::Array<float, 1, 1> const& wavelength
) const;
ndarray::Array<float, 2, 1> findPoint(
ndarray::Array<int, 1, 1> const& fiberId,
ndarray::Array<float, 1, 1> const& wavelength
) const;
//@}

//@{
/// Return the wavelength of a point on the detector, given a fiberId and row
virtual float findWavelength(int fiberId, float row) const = 0;
float findWavelength(int fiberId, float row) const {
return findWavelengthImpl(fiberId, row);
}
ndarray::Array<float, 1, 1> findWavelength(
int fiberId,
ndarray::Array<float, 1, 1> const& row
) const;
ndarray::Array<float, 1, 1> findWavelength(
ndarray::Array<int, 1, 1> const& fiberId,
ndarray::Array<float, 1, 1> const& row
) const;
//@}

VisitInfo getVisitInfo() const { return _visitInfo; }
void setVisitInfo(VisitInfo &visitInfo) { _visitInfo = visitInfo; };
Expand Down Expand Up @@ -110,6 +134,16 @@ class DetectorMap : public lsst::afw::table::io::Persistable {
/// Return the index of a fiber, given its fiber ID
std::size_t getFiberIndex(int fiberId) const { return _fiberMap.at(fiberId); }

/// Return the position of the fiber trace on the detector, given a fiberId and wavelength
///
/// Implementation of findPoint, for subclasses to define.
virtual lsst::geom::PointD findPointImpl(int fiberId, float wavelength) const = 0;

/// Return the wavelength of a point on the detector, given a fiberId and row
///
/// Implementation of findWavelength, for subclasses to define.
virtual float findWavelengthImpl(int fiberId, float row) const = 0;

/// Reset cached elements after setting slit offsets
virtual void _resetSlitOffsets() {}

Expand Down
15 changes: 6 additions & 9 deletions include/pfs/drp/stella/GlobalDetectorMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,6 @@ class GlobalDetectorMap : public DetectorMap {
/// Return the fiberId given a position on the detector
virtual int findFiberId(lsst::geom::PointD const& point) const override;

//@{
/// Return the position of the fiber trace on the detector, given a fiberId and wavelength
virtual lsst::geom::PointD findPoint(int fiberId, float wavelength) const override;
virtual Array2D findPoint(FiberIds const& fiberId, Array1D const& wavelength) const;
//@}

/// Return the wavelength of a point on the detector, given a fiberId and row
virtual float findWavelength(int fiberId, float row) const override;

GlobalDetectorModel getModel() const { return _model; }
int getDistortionOrder() const { return _model.getDistortionOrder(); }

Expand All @@ -109,6 +100,12 @@ class GlobalDetectorMap : public DetectorMap {
class Factory;

protected:
/// Return the position of the fiber trace on the detector, given a fiberId and wavelength
virtual lsst::geom::PointD findPointImpl(int fiberId, float wavelength) const override;

/// Return the wavelength of a point on the detector, given a fiberId and row
virtual float findWavelengthImpl(int fiberId, float row) const override;

std::string getPersistenceName() const { return "GlobalDetectorMap"; }
std::string getPythonModule() const { return "pfs.drp.stella"; }
void write(lsst::afw::table::io::OutputArchiveHandle & handle) const;
Expand Down
12 changes: 6 additions & 6 deletions include/pfs/drp/stella/SplinedDetectorMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,6 @@ class SplinedDetectorMap : public DetectorMap {
/// Return the fiberId given a position on the detector
virtual int findFiberId(lsst::geom::PointD const& point) const override;

/// Return the position of the fiber trace on the detector, given a fiberId and wavelength
virtual lsst::geom::PointD findPoint(int fiberId, float wavelength) const override;

/// Return the wavelength of a point on the detector, given a fiberId and row
virtual float findWavelength(int fiberId, float row) const override;

math::Spline<float> const& getXCenterSpline(int fiberId) const;
math::Spline<float> const& getWavelengthSpline(int fiberId) const;

Expand All @@ -110,6 +104,12 @@ class SplinedDetectorMap : public DetectorMap {
class Factory;

protected:
/// Return the position of the fiber trace on the detector, given a fiberId and wavelength
virtual lsst::geom::PointD findPointImpl(int fiberId, float wavelength) const override;

/// Return the wavelength of a point on the detector, given a fiberId and row
virtual float findWavelengthImpl(int fiberId, float row) const override;

std::string getPersistenceName() const { return "SplinedDetectorMap"; }
std::string getPythonModule() const { return "pfs.drp.stella"; }
void write(lsst::afw::table::io::OutputArchiveHandle & handle) const;
Expand Down
17 changes: 0 additions & 17 deletions include/pfs/drp/stella/python/DetectorMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,6 @@ auto wrapDetectorMap(py::module & mod, char const* name) {
"fiberId"_a, "row"_a);
cls.def("findFiberId", py::overload_cast<lsst::geom::PointD const&>(&Class::findFiberId, py::const_),
"point"_a);
cls.def("findPoint", py::overload_cast<int, float>(&Class::findPoint, py::const_),
"fiberId"_a, "wavelength"_a);
cls.def("findPoint",
[](Class const& self, ndarray::Array<int, 1, 1> const& fiberId,
ndarray::Array<float, 1, 1> const& wavelength) {
std::size_t const num = fiberId.size();
utils::checkSize(wavelength.size(), num, "fiberId/wavelength");
ndarray::Array<float, 2, 1> out = ndarray::allocate(num, 2);
for (std::size_t ii = 0; ii < num; ++ii) {
auto const point = self.findPoint(fiberId[ii], wavelength[ii]);
out[ii][0] = point.getX();
out[ii][1] = point.getY();
}
return out;
}, "fiberId"_a, "wavelength"_a);
cls.def("findWavelength", py::overload_cast<int, float>(&Class::findWavelength, py::const_),
"fiberId"_a, "row"_a);
return cls;
}

Expand Down
11 changes: 9 additions & 2 deletions python/pfs/drp/stella/DetectorMap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,23 @@ void declareDetectorMap(py::module & mod) {
"fiberId"_a, "spatial"_a, "spectral"_a);
cls.def("getXCenter", py::overload_cast<>(&Class::getXCenter, py::const_));
cls.def("getWavelength", py::overload_cast<>(&Class::getWavelength, py::const_));
#if 0
cls.def("findPoint", py::overload_cast<int, float>(&Class::findPoint, py::const_),
"fiberId"_a, "wavelength"_a);
cls.def("findPoint", py::overload_cast<int, Class::Array1D const&>(&Class::findPoint, py::const_),
"fiberId"_a, "wavelength"_a);
cls.def("findPoint",
py::overload_cast<Class::FiberIds const&,
Class::Array1D const&>(&Class::findPoint, py::const_),
"fiberId"_a, "wavelength"_a);
cls.def("findWavelength", py::overload_cast<int, float>(&Class::findWavelength, py::const_),
"fiberId"_a, "row"_a);
cls.def("findWavelength",
py::overload_cast<int, Class::Array1D const&>(&Class::findWavelength, py::const_),
"fiberId"_a, "row"_a);
cls.def("findWavelength",
py::overload_cast<Class::FiberIds const&,
Class::Array1D const&>(&Class::findWavelength, py::const_),
"fiberId"_a, "row"_a);
#endif
cls.def("getVisitInfo", &Class::getVisitInfo);
cls.def("setVisitInfo", &Class::setVisitInfo, "visitInfo"_a);
cls.def_property("visitInfo", &Class::getVisitInfo, &Class::setVisitInfo);
Expand Down
58 changes: 58 additions & 0 deletions src/DetectorMap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,62 @@ DetectorMap::Array2D DetectorMap::getXCenter() const {
}


ndarray::Array<float, 2, 1> DetectorMap::findPoint(
int fiberId,
ndarray::Array<float, 1, 1> const& wavelength
) const {
std::size_t const length = wavelength.size();
ndarray::Array<float, 2, 1> out = ndarray::allocate(length, 2);
for (std::size_t ii = 0; ii < length; ++ii) {
auto const point = findPointImpl(fiberId, wavelength[ii]);
out[ii][0] = point.getX();
out[ii][1] = point.getY();
}
return out;
}


ndarray::Array<float, 2, 1> DetectorMap::findPoint(
ndarray::Array<int, 1, 1> const& fiberId,
ndarray::Array<float, 1, 1> const& wavelength
) const {
std::size_t const length = fiberId.size();
utils::checkSize(length, wavelength.size(), "fiberId vs wavelength");
ndarray::Array<float, 2, 1> out = ndarray::allocate(length, 2);
for (std::size_t ii = 0; ii < length; ++ii) {
auto const point = findPointImpl(fiberId[ii], wavelength[ii]);
out[ii][0] = point.getX();
out[ii][1] = point.getY();
}
return out;
}


ndarray::Array<float, 1, 1> DetectorMap::findWavelength(
int fiberId,
ndarray::Array<float, 1, 1> const& row
) const {
std::size_t const length = row.size();
ndarray::Array<float, 1, 1> out = ndarray::allocate(length);
for (std::size_t ii = 0; ii < length; ++ii) {
out[ii] = findWavelengthImpl(fiberId, row[ii]);
}
return out;
}


ndarray::Array<float, 1, 1> DetectorMap::findWavelength(
ndarray::Array<int, 1, 1> const& fiberId,
ndarray::Array<float, 1, 1> const& row
) const {
std::size_t const length = fiberId.size();
utils::checkSize(length, row.size(), "fiberId vs row");
ndarray::Array<float, 1, 1> out = ndarray::allocate(length);
for (std::size_t ii = 0; ii < length; ++ii) {
out[ii] = findWavelengthImpl(fiberId[ii], row[ii]);
}
return out;
}


}}} // namespace pfs::drp::stella
20 changes: 2 additions & 18 deletions src/GlobalDetectorMap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,31 +221,15 @@ int GlobalDetectorMap::findFiberId(lsst::geom::PointD const& point) const {
}


lsst::geom::PointD GlobalDetectorMap::findPoint(
lsst::geom::PointD GlobalDetectorMap::findPointImpl(
int fiberId,
float wavelength
) const {
return _model(fiberId, wavelength);
}


GlobalDetectorMap::Array2D GlobalDetectorMap::findPoint(
GlobalDetectorMap::FiberIds const& fiberId,
GlobalDetectorMap::Array1D const& wavelength
) const {
std::size_t const length = fiberId.size();
utils::checkSize(wavelength.size(), length, "wavelength");
Array2D out = ndarray::allocate(2, length);
for (std::size_t ii = 0; ii < length; ++ii) {
auto const point = _model(fiberId[ii], wavelength[ii]);
out[0][ii] = point.getX();
out[1][ii] = point.getY();
}
return out;
}


float GlobalDetectorMap::findWavelength(int fiberId, float row) const {
float GlobalDetectorMap::findWavelengthImpl(int fiberId, float row) const {
Spline const& spline = _rowToWavelength[getFiberIndex(fiberId)];
return spline(row);
}
Expand Down
4 changes: 2 additions & 2 deletions src/SplinedDetectorMap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ int SplinedDetectorMap::findFiberId(
}


lsst::geom::PointD SplinedDetectorMap::findPoint(
lsst::geom::PointD SplinedDetectorMap::findPointImpl(
int fiberId,
float wavelength
) const {
Expand Down Expand Up @@ -231,7 +231,7 @@ lsst::geom::PointD SplinedDetectorMap::findPoint(
}


float SplinedDetectorMap::findWavelength(
float SplinedDetectorMap::findWavelengthImpl(
int fiberId,
float row
) const {
Expand Down

0 comments on commit 9ee8c6a

Please sign in to comment.