ROSE  0.11.145.0
util/Sawyer/Database.h
1 // WARNING: Changes to this file must be contributed back to Sawyer or else they will
2 // be clobbered by the next update from Sawyer. The Sawyer repository is at
3 // https://github.com/matzke1/sawyer.
4 
5 
6 
7 
8 #ifndef Sawyer_Database_H
9 #define Sawyer_Database_H
10 
11 #if __cplusplus >= 201103L
12 
#include <memory.h>

#include <cctype>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <boost/iterator/iterator_facade.hpp>
#include <boost/lexical_cast.hpp>
#include <boost/numeric/conversion/cast.hpp>

#include <Sawyer/Assert.h>
#include <Sawyer/Map.h>
#include <Sawyer/Optional.h>
22 
23 namespace Sawyer {
24 
167 namespace Database {
168 
169 class Connection;
170 class Statement;
171 class Row;
172 class Iterator;
173 
174 namespace Detail {
175  class ConnectionBase;
176  class StatementBase;
177 }
178 
/** Exception type thrown by the database API.
 *
 *  Carries nothing beyond the human-readable message given at construction, which is retrievable through @c what(). */
class Exception: public std::runtime_error {
public:
    /** Construct an exception whose @c what() text is @p what. */
    Exception(const std::string &what)
        : std::runtime_error(what) {}

    ~Exception() noexcept override = default;
};
186 
188 // Connection
190 
196 class Connection {
197  friend class ::Sawyer::Database::Statement;
198  friend class ::Sawyer::Database::Detail::ConnectionBase;
199 
200  std::shared_ptr<Detail::ConnectionBase> pimpl_;
201 
202 public:
204  Connection() {};
205 
206 private:
207  explicit Connection(const std::shared_ptr<Detail::ConnectionBase> &pimpl);
208 
209 public:
213  ~Connection() = default;
214 
216  static Connection fromUri(const std::string &uri);
217 
221  static std::string uriDocString();
222 
224  bool isOpen() const;
225 
231  Connection& close();
232 
247  Statement stmt(const std::string &sql);
248 
250  Connection& run(const std::string &sql);
251 
257  template<typename T>
258  Optional<T> get(const std::string &sql);
259 
263  std::string driverName() const;
264 
265  // Undocumented: Row number for the last SQL "insert" (do not use).
266  //
267  // This method is available only if the underlying database driver supports it and it has lots of caveats. In other words,
268  // don't use this method. The most portable way to identify the rows that were just inserted is to insert a UUID as part of
269  // the data.
270  size_t lastInsert() const;
271 
272  // Set the pointer to implementation
273  void pimpl(const std::shared_ptr<Detail::ConnectionBase> &p) {
274  pimpl_ = p;
275  }
276 };
277 
279 // Statement
281 
286 class Statement {
287  friend class ::Sawyer::Database::Detail::ConnectionBase;
288 
289  std::shared_ptr<Detail::StatementBase> pimpl_;
290 
291 public:
293  enum State {
294  UNBOUND,
295  READY,
296  EXECUTING,
297  FINISHED,
298  DEAD
299  };
300 
301 public:
303  Statement() {}
304 
305 private:
306  explicit Statement(const std::shared_ptr<Detail::StatementBase> &stmt)
307  : pimpl_(stmt) {}
308 
309 public:
311  Connection connection() const;
312 
322  template<typename T>
323  Statement& bind(const std::string &name, const T &value);
324 
329  template<typename T>
330  Statement& rebind(const std::string &name, const T &value);
331 
338  Iterator begin();
339 
341  Iterator end();
342 
347  Statement& run();
348 
353  template<typename T>
354  Optional<T> get();
355 };
356 
358 // Row
360 
365 class Row {
366  friend class ::Sawyer::Database::Iterator;
367 
368  std::shared_ptr<Detail::StatementBase> stmt_;
369  size_t sequence_; // for checking validity
370 
371 private:
372  Row()
373  : sequence_(0) {}
374 
375  explicit Row(const std::shared_ptr<Detail::StatementBase> &stmt);
376 
377 public:
379  template<typename T>
380  Optional<T> get(size_t columnIdx) const;
381 
385  size_t rowNumber() const;
386 };
387 
389 // Iterator
391 
398 class Iterator: public boost::iterator_facade<Iterator, const Row, boost::forward_traversal_tag> {
399  friend class ::Sawyer::Database::Detail::StatementBase;
400 
401  Row row_;
402 
403 public:
405  Iterator() {}
406 
407 private:
408  explicit Iterator(const std::shared_ptr<Detail::StatementBase> &stmt);
409 
410 public:
412  bool isEnd() const {
413  return !row_.stmt_;
414  }
415 
417  explicit operator bool() const {
418  return !isEnd();
419  }
420 
421 private:
422  friend class boost::iterator_core_access;
423  const Row& dereference() const;
424  bool equal(const Iterator&) const;
425  void increment();
426 };
427 
428 
432 //
433 // Only implementation details beyond this point.
434 //
438 
439 
440 namespace Detail {
441 
442 // Base class for connection details. The individual drivers (SQLite3, PostgreSQL) will be derived from this class.
443 //
444 // Connection detail objects are reference counted. References come from only two places:
445 // 1. Each top-level Connection object that's in a connected state has a reference to this connection.
446 // 2. Each low-level Statement object that's in an "executing" state has a reference to this connection.
447 // Additionally, all low-level statement objects have a weak reference to a connection.
448 //
449 class ConnectionBase: public std::enable_shared_from_this<ConnectionBase> {
450  friend class ::Sawyer::Database::Connection;
451 
452 protected:
453  ConnectionBase() {}
454 
455 public:
456  virtual ~ConnectionBase() {}
457 
458 protected:
459  // Close any low-level connection.
460  virtual void close() = 0;
461 
462  // Create a prepared statement from the specified high-level SQL. By "high-level" we mean the binding syntax used by this
463  // API such as "?name" (whereas low-level means the syntax passed to the driver such as "?").
464  virtual Statement prepareStatement(const std::string &sql) = 0;
465 
466  // Row number for the last inserted row if supported by this driver. It's better to use a table column that holds a value
467  // generated from a sequence.
468  virtual size_t lastInsert() const = 0;
469 
470  Statement makeStatement(const std::shared_ptr<Detail::StatementBase> &detail);
471 
472  virtual std::string driverName() const = 0;
473 };
474 
475 // Describes the location of "?name" parameters in high-level SQL by associating them with one or more "?" parameters in
476 // low-level SQL. WARNIN: the low-level parameters are numbered starting at one instead of zero, which is inconsistent with how
477 // the low-level APIs index other things like query result columns (not to mention being surprising for C and C++ developers).
478 class Parameter {
479  friend class ::Sawyer::Database::Detail::StatementBase;
480 
481  std::vector<size_t> indexes; // "?" indexes
482  bool isBound = false;
483 
484  void append(size_t idx) {
485  indexes.push_back(idx);
486  }
487 };
488 
489 template<typename T>
490 class ColumnReader {
491  friend class ::Sawyer::Database::Detail::StatementBase;
492  Optional<T> operator()(StatementBase *stmt, size_t idx);
493 };
494 
495 //template<>
496 //class ColumnReader<std::vector<uint8_t>> {
497 // friend class ::Sawyer::Database::Detail::StatementBase;
498 // Optional<std::vector<uint8_t>> operator()(StatementBase *stmt, size_t idx);
499 //};
500 
501 // Reference counted prepared statement details. Objects of this class are referenced from the high-level Statement objects and
502 // the query iterator rows. This class is the base class for driver-specific statements.
503 class StatementBase: public std::enable_shared_from_this<StatementBase> {
504  friend class ::Sawyer::Database::Iterator;
505  friend class ::Sawyer::Database::Row;
506  friend class ::Sawyer::Database::Statement;
507  template<class T> friend class ::Sawyer::Database::Detail::ColumnReader;
508 
509  using Parameters = Container::Map<std::string, Parameter>;
510 
511  std::shared_ptr<ConnectionBase> connection_; // non-null while statement is executing
512  std::weak_ptr<ConnectionBase> weakConnection_; // refers to the originating connection
513  Parameters params_; // mapping from param names to question marks
514  Statement::State state_ = Statement::DEAD; // don't set directly; use "state" member function
515  size_t sequence_ = 0; // sequence number for invalidating row iterators
516  size_t rowNumber_ = 0; // result row number
517 
518 public:
519  virtual ~StatementBase() {}
520 
521 protected:
522  explicit StatementBase(const std::shared_ptr<ConnectionBase> &connection)
523  : weakConnection_(connection) { // save only a weak pointer, no shared pointer
524  ASSERT_not_null(connection);
525  }
526 
527  // Parse the high-level SQL (with "?name" parameters) into low-level SQL (with "?" parameters). Returns the low-level SQL
528  // and the number of low-level "?" parameters and has the following side effects:
529  // 1. Re-initializes this object's parameter list
530  // 2. Sets this object's state to READY, UNBOUND, or DEAD.
531  std::pair<std::string, size_t> parseParameters(const std::string &highSql) {
532  params_.clear();
533  std::string lowSql;
534  bool inString = false;
535  size_t nLowParams = 0;
536  state(Statement::READY); // possibly reset below
537  for (size_t i = 0; i < highSql.size(); ++i) {
538  if ('\'' == highSql[i]) {
539  inString = !inString; // works for "''" escape too
540  lowSql += highSql[i];
541  } else if ('?' == highSql[i] && !inString) {
542  lowSql += '?';
543  std::string paramName;
544  while (i+1 < highSql.size() && (::isalnum(highSql[i+1]) || '_' == highSql[i+1]))
545  paramName += highSql[++i];
546  if (paramName.empty())
547  throw Exception("invalid parameter name at character position " + boost::lexical_cast<std::string>(i));
548  Parameter &param = params_.insertMaybeDefault(paramName);
549  param.append(nLowParams++); // 0-origin low-level parameter numbers
550  state(Statement::UNBOUND);
551  } else {
552  lowSql += highSql[i];
553  }
554  }
555  if (inString) {
556  state(Statement::DEAD);
557  throw Exception("mismatched quotes in SQL statement");
558  }
559  return std::make_pair(lowSql, nLowParams);
560  }
561 
562  // Invalidate all iterators and their rows by incrementing this statements sequence number.
563  void invalidateIteratorsAndRows() {
564  ++sequence_;
565  }
566 
567  // Sequence number used for checking iterator validity.
568  size_t sequence() const {
569  return sequence_;
570  }
571 
572  // Cause this statement to lock the database connection by maintaining a shared pointer to the low-level
573  // connection. Returns true if the connection could be locked, or false if unable.
574  bool lockConnection() {
575  return (connection_ = weakConnection_.lock()) != nullptr;
576  }
577 
578  // Release the connection lock by throwing away the shared pointer to the connection. This statement will still maintain
579  // a weak reference to the connection.
580  void unlockConnection() {
581  connection_.reset();
582  }
583 
584  // Returns an indication of whether this statement holds a lock on the low-level connection, preventing the connection from
585  // being destroyed.
586  bool isConnectionLocked() const {
587  return connection_ != nullptr;
588  }
589 
590  // Returns the connection details associated with this statement. The connection is not locked by querying this property.
591  std::shared_ptr<ConnectionBase> connection() const {
592  return weakConnection_.lock();
593  }
594 
595  // Return the current statement state.
596  Statement::State state() const {
597  return state_;
598  }
599 
600  // Change the statement state. A statement in the EXECUTING state will lock the connection to prevent it from being
601  // destroyed, but a statement in any other state will unlock the connection causing the last reference to destroy the
602  // connection and will invalidate all iterators and rows.
603  void state(Statement::State newState) {
604  switch (newState) {
605  case Statement::DEAD:
606  case Statement::FINISHED:
607  case Statement::UNBOUND:
608  case Statement::READY:
609  invalidateIteratorsAndRows();
610  unlockConnection();
611  break;
612  case Statement::EXECUTING:
613  ASSERT_require(isConnectionLocked());
614  break;
615  }
616  state_ = newState;
617  }
618 
619  // Returns true if this statement has parameters that have not been bound to a value.
620  bool hasUnboundParameters() const {
621  ASSERT_forbid(state() == Statement::DEAD);
622  for (const Parameter &param: params_.values()) {
623  if (!param.isBound)
624  return true;
625  }
626  return false;
627  }
628 
629  // Causes all parameters to become unbound and changes the state to either UNBOUND or READY (depending on whether there are
630  // any parameters or not, respectively).
631  virtual void unbindAllParams() {
632  ASSERT_forbid(state() == Statement::DEAD);
633  for (Parameter &param: params_.values())
634  param.isBound = false;
635  state(params_.isEmpty() ? Statement::READY : Statement::UNBOUND);
636  }
637 
638  // Reset the statement by invalidating all iterators, unbinding all parameters, and changing the state to either UNBOUND or
639  // READY depending on whether or not it has any parameters.
640  virtual void reset(bool doUnbind) {
641  ASSERT_forbid(state() == Statement::DEAD);
642  invalidateIteratorsAndRows();
643  if (doUnbind) {
644  unbindAllParams();
645  } else {
646  state(hasUnboundParameters() ? Statement::UNBOUND : Statement::READY);
647  }
648  }
649 
650  // Bind a value to a parameter. If isRebind is set and the statement is in the EXECUTING state, then rewind back to the
651  // READY state, preserve all previous bindings, and adjust only the specified binding.
652  template<typename T>
653  void bind(const std::string &name, const T &value, bool isRebind) {
654  if (!connection())
655  throw Exception("connection is closed");
656  switch (state()) {
657  case Statement::DEAD:
658  throw Exception("statement is dead");
659  case Statement::FINISHED:
660  case Statement::EXECUTING:
661  reset(!isRebind);
662  // fall through
663  case Statement::READY:
664  case Statement::UNBOUND: {
665  if (!params_.exists(name))
666  throw Exception("no such parameter \"" + name + "\" in statement");
667  Parameter &param = params_[name];
668  bool wasUnbound = !param.isBound;
669  for (size_t idx: param.indexes) {
670  try {
671  bindLow(idx, value);
672  } catch (const Exception &e) {
673  if (param.indexes.size() > 1)
674  state(Statement::DEAD); // might be only partly bound now
675  throw e;
676  }
677  }
678  param.isBound = true;
679 
680  if (wasUnbound && !hasUnboundParameters())
681  state(Statement::READY);
682  break;
683  }
684  }
685  }
686 
687  // Bind a value to an optional parameter.
688  template<typename T>
689  void bind(const std::string &name, const Sawyer::Optional<T> &value, bool isRebind) {
690  if (value) {
691  bind(name, *value, isRebind);
692  } else {
693  bind(name, Nothing(), isRebind);
694  }
695  }
696 
697  // Driver-specific part of binding by specifying the 0-origin low-level "?" number and the value.
698  virtual void bindLow(size_t idx, int value) = 0;
699  virtual void bindLow(size_t idx, int64_t value) = 0;
700  virtual void bindLow(size_t idx, size_t value) = 0;
701  virtual void bindLow(size_t idx, double value) = 0;
702  virtual void bindLow(size_t idx, const std::string &value) = 0;
703  virtual void bindLow(size_t idx, const char *cstring) = 0;
704  virtual void bindLow(size_t idx, Nothing) = 0;
705  virtual void bindLow(size_t idx, const std::vector<uint8_t> &data) = 0;
706 
707  Iterator makeIterator() {
708  return Iterator(shared_from_this());
709  }
710 
711  // Begin execution of a statement in the READY state. If the statement is in the FINISHED or EXECUTING state it will be
712  // restarted.
713  Iterator begin() {
714  if (!connection())
715  throw Exception("connection is closed");
716  switch (state()) {
717  case Statement::DEAD:
718  throw Exception("statement is dead");
719  case Statement::UNBOUND: {
720  std::string s;
721  for (Parameters::Node &param: params_.nodes()) {
722  if (!param.value().isBound)
723  s += (s.empty() ? "" : ", ") + param.key();
724  }
725  ASSERT_forbid(s.empty());
726  throw Exception("unbound parameters: " + s);
727  }
728  case Statement::FINISHED:
729  case Statement::EXECUTING:
730  reset(false);
731  // fall through
732  case Statement::READY: {
733  if (!lockConnection())
734  throw Exception("connection has been closed");
735  state(Statement::EXECUTING);
736  rowNumber_ = 0;
737  Iterator iter = beginLow();
738  rowNumber_ = 0; // in case beginLow changed it
739  return iter;
740  }
741  }
742  ASSERT_not_reachable("invalid state");
743  }
744 
745  // The driver-specific component of "begin". The statement is guaranteed to be in the EXECUTING state when called,
746  // but could be in some other state after returning.
747  virtual Iterator beginLow() = 0;
748 
749  // Advance an executing statement to the next row
750  Iterator next() {
751  if (!connection())
752  throw Exception("connection is closed");
753  ASSERT_require(state() == Statement::EXECUTING); // no other way to get here
754  invalidateIteratorsAndRows();
755  ++rowNumber_;
756  return nextLow();
757  }
758 
759  // Current row number
760  size_t rowNumber() const {
761  return rowNumber_;
762  }
763 
764  // The driver-specific component of "next". The statement is guaranteed to be in the EXECUTING state when called, but
765  // could be in some other state after returning.
766  virtual Iterator nextLow() = 0;
767 
768  // Get a column value from the current row of result
769  template<typename T>
770  Optional<T> get(size_t columnIdx) {
771  if (!connection())
772  throw Exception("connection is closed");
773  ASSERT_require(state() == Statement::EXECUTING); // no other way to get here
774  if (columnIdx >= nColumns())
775  throw Exception("column index " + boost::lexical_cast<std::string>(columnIdx) + " is out of range");
776  return ColumnReader<T>()(this, columnIdx);
777  }
778 
779  // Number of columns returned by a query.
780  virtual size_t nColumns() const = 0;
781 
782  // Get the value of a particular column of the current row.
783  virtual Optional<std::string> getString(size_t idx) = 0;
784  virtual Optional<std::vector<std::uint8_t>> getBlob(size_t idx) = 0;
785 };
786 
787 template<typename T>
788 inline Optional<T>
789 ColumnReader<T>::operator()(StatementBase *stmt, size_t idx) {
790  std::string str;
791  if (!stmt->getString(idx).assignTo(str))
792  return Nothing();
793  return boost::lexical_cast<T>(str);
794 }
795 
796 template<>
797 inline Optional<std::vector<uint8_t>>
798 ColumnReader<std::vector<uint8_t>>::operator()(StatementBase *stmt, size_t idx) {
799  return stmt->getBlob(idx);
800 }
801 
802 inline Statement
803 ConnectionBase::makeStatement(const std::shared_ptr<Detail::StatementBase> &detail) {
804  return Statement(detail);
805 }
806 
807 } // namespace
808 
810 // Implementations Connection
812 
813 
814 inline Connection::Connection(const std::shared_ptr<Detail::ConnectionBase> &pimpl)
815  : pimpl_(pimpl) {}
816 
817 inline bool
818 Connection::isOpen() const {
819  return pimpl_ != nullptr;
820 }
821 
822 inline Connection&
823 Connection::close() {
824  pimpl_ = nullptr;
825  return *this;
826 }
827 
828 inline std::string
829 Connection::driverName() const {
830  if (pimpl_) {
831  return pimpl_->driverName();
832  } else {
833  return "";
834  }
835 }
836 
837 inline Statement
838 Connection::stmt(const std::string &sql) {
839  if (pimpl_) {
840  return pimpl_->prepareStatement(sql);
841  } else {
842  throw Exception("no active database connection");
843  }
844 }
845 
846 inline Connection&
847 Connection::run(const std::string &sql) {
848  stmt(sql).begin();
849  return *this;
850 }
851 
852 template<typename T>
853 inline Optional<T>
854 Connection::get(const std::string &sql) {
855  for (auto row: stmt(sql))
856  return row.get<T>(0);
857  return Nothing();
858 }
859 
860 inline size_t
861 Connection::lastInsert() const {
862  if (pimpl_) {
863  return pimpl_->lastInsert();
864  } else {
865  throw Exception("no active database connection");
866  }
867 }
868 
870 // Implementations for Statement
872 
873 inline Connection
874 Statement::connection() const {
875  if (pimpl_) {
876  return Connection(pimpl_->connection());
877  } else {
878  return Connection();
879  }
880 }
881 
882 template<typename T>
883 inline Statement&
884 Statement::bind(const std::string &name, const T &value) {
885  if (pimpl_) {
886  pimpl_->bind(name, value, false);
887  } else {
888  throw Exception("no active database connection");
889  }
890  return *this;
891 }
892 
893 template<typename T>
894 inline Statement&
895 Statement::rebind(const std::string &name, const T &value) {
896  if (pimpl_) {
897  pimpl_->bind(name, value, true);
898  } else {
899  throw Exception("no active database connection");
900  }
901  return *this;
902 }
903 
904 inline Iterator
905 Statement::begin() {
906  if (pimpl_) {
907  return pimpl_->begin();
908  } else {
909  throw Exception("no active database connection");
910  }
911 }
912 
913 inline Iterator
914 Statement::end() {
915  return Iterator();
916 }
917 
918 inline Statement&
919 Statement::run() {
920  begin();
921  return *this;
922 }
923 
924 template<typename T>
925 inline Optional<T>
926 Statement::get() {
927  Iterator row = begin();
928  if (row.isEnd())
929  throw Exception("query did not return a row");
930  return row->get<T>(0);
931 }
932 
934 // Implementations for Iterator
936 
937 inline
938 Iterator::Iterator(const std::shared_ptr<Detail::StatementBase> &stmt)
939  : row_(stmt) {}
940 
941 inline const Row&
942 Iterator::dereference() const {
943  if (isEnd())
944  throw Exception("dereferencing the end iterator");
945  if (row_.sequence_ != row_.stmt_->sequence())
946  throw Exception("iterator has been invalidated");
947  return row_;
948 }
949 
950 inline bool
951 Iterator::equal(const Iterator &other) const {
952  return row_.stmt_ == other.row_.stmt_ && row_.sequence_ == other.row_.sequence_;
953 }
954 
955 inline void
956 Iterator::increment() {
957  if (isEnd())
958  throw Exception("incrementing the end iterator");
959  *this = row_.stmt_->next();
960 }
961 
963 // Implementations for Row
965 
966 inline
967 Row::Row(const std::shared_ptr<Detail::StatementBase> &stmt)
968  : stmt_(stmt), sequence_(stmt ? stmt->sequence() : 0) {}
969 
970 template<typename T>
971 inline Optional<T>
972 Row::get(size_t columnIdx) const {
973  ASSERT_not_null(stmt_);
974  if (sequence_ != stmt_->sequence())
975  throw Exception("row has been invalidated");
976  return stmt_->get<T>(columnIdx);
977 }
978 
979 inline size_t
980 Row::rowNumber() const {
981  ASSERT_not_null(stmt_);
982  if (sequence_ != stmt_->sequence())
983  throw Exception("row has been invalidated");
984  return stmt_->rowNumber();
985 }
986 
987 } // namespace
988 } // namespace
989 
990 #endif
991 #endif
STL namespace.
Holds a value or nothing.
Definition: Optional.h:49
bool increment(Word *vec1, const BitRange &range1)
Increment.
Name space for the entire library.
Definition: FeasiblePath.h:767
State
Decoder state.
Definition: String.h:198