8#ifndef Sawyer_Database_H
9#define Sawyer_Database_H
11#if __cplusplus >= 201103L
13#include <boost/iterator/iterator_facade.hpp>
14#include <boost/lexical_cast.hpp>
15#include <boost/numeric/conversion/cast.hpp>
17#include <Sawyer/Assert.h>
18#include <Sawyer/Map.h>
19#include <Sawyer/Optional.h>
// Forward declaration of the abstract connection implementation; it is referred to as
// Detail::ConnectionBase elsewhere (the enclosing namespace lines are not in this extract).
175 class ConnectionBase;
// Exception type thrown throughout this database layer, e.g. for closed connections, dead
// statements, unknown or unbound parameters, and invalidated rows/iterators.
179class Exception:
public std::runtime_error {
// Construct with a human-readable message, forwarded unchanged to std::runtime_error.
181 Exception(
const std::string &what)
182 : std::runtime_error(what) {}
184 ~Exception() noexcept {}
// Connection (class header not in this extract): user-facing handle to a database connection.
// Operations are delegated to a shared Detail::ConnectionBase implementation object.
197 friend class ::Sawyer::Database::Statement;
198 friend class ::Sawyer::Database::Detail::ConnectionBase;
// Implementation object; null means there is no active connection (isOpen() tests exactly this).
200 std::shared_ptr<Detail::ConnectionBase> pimpl_;
// Wrap an existing implementation object.
207 explicit Connection(
const std::shared_ptr<Detail::ConnectionBase> &pimpl);
213 ~Connection() =
default;
// Create a connection from a URI string. NOTE(review): the accepted URI syntax is defined
// elsewhere (presumably what uriDocString() describes).
216 static Connection fromUri(
const std::string &uri);
// Returns a documentation string; presumably describes the URI syntax accepted by fromUri().
221 static std::string uriDocString();
// Prepare a statement from SQL text; throws Exception("no active database connection") when
// there is no implementation object.
247 Statement stmt(
const std::string &sql);
// Execute SQL text. NOTE(review): the definition's body is not in this extract.
250 Connection& run(
const std::string &sql);
// Run a query and return column 0 of its first row converted to T (T's template line is not
// in this extract).
258 Optional<T>
get(
const std::string &sql);
// Driver name as reported by the implementation object.
263 std::string driverName()
const;
// Forwarded to the implementation's lastInsert(); throws when there is no active connection.
270 size_t lastInsert()
const;
// Setter for the implementation object (body not in this extract).
273 void pimpl(
const std::shared_ptr<Detail::ConnectionBase> &p) {
// Statement (class header not in this extract): user-facing handle to a prepared statement whose
// implementation is a shared Detail::StatementBase. ConnectionBase is a friend;
// ConnectionBase::makeStatement() builds Statements from implementation objects.
287 friend class ::Sawyer::Database::Detail::ConnectionBase;
289 std::shared_ptr<Detail::StatementBase> pimpl_;
// Wrap an implementation object (member-initializer line not in this extract).
306 explicit Statement(
const std::shared_ptr<Detail::StatementBase> &stmt)
// Connection this statement was prepared on. StatementBase::connection() locks a weak_ptr, so the
// result may wrap a null implementation if that connection no longer exists.
311 Connection connection()
const;
// Bind a value to the named parameter (forwards to the implementation with isRebind=false).
323 Statement& bind(
const std::string &name,
const T &value);
// Like bind(), but forwards with isRebind=true.
330 Statement& rebind(
const std::string &name,
const T &value);
// Row (class header not in this extract): one result row of an executing statement. A Row
// records the statement's sequence number when constructed. get() and rowNumber() throw
// Exception("row has been invalidated") once that number has changed.
366 friend class ::Sawyer::Database::Iterator;
368 std::shared_ptr<Detail::StatementBase> stmt_;
375 explicit Row(
const std::shared_ptr<Detail::StatementBase> &stmt);
// Value of column columnIdx converted to T (T's template line is not in this extract).
380 Optional<T>
get(
size_t columnIdx)
const;
// Row number as reported by the statement.
385 size_t rowNumber()
const;
// Forward iterator over the result rows of a statement, built on boost::iterator_facade;
// dereferencing yields a const Row&.
398class Iterator:
public boost::iterator_facade<Iterator, const Row, boost::forward_traversal_tag> {
399 friend class ::Sawyer::Database::Detail::StatementBase;
408 explicit Iterator(
const std::shared_ptr<Detail::StatementBase> &stmt);
// NOTE(review): body not in this extract; presumably false for the end iterator.
417 explicit operator bool()
const {
// iterator_facade reaches dereference(), equal() and increment() through iterator_core_access.
422 friend class boost::iterator_core_access;
423 const Row& dereference()
const;
424 bool equal(
const Iterator&)
const;
// Abstract connection implementation: the object behind Connection::pimpl_. Subclasses supply
// the pure-virtual operations. Derives from enable_shared_from_this.
449class ConnectionBase:
public std::enable_shared_from_this<ConnectionBase> {
450 friend class ::Sawyer::Database::Connection;
456 virtual ~ConnectionBase() {}
460 virtual void close() = 0;
// Compile SQL text into a Statement; Connection::stmt() forwards here.
464 virtual Statement prepareStatement(
const std::string &sql) = 0;
468 virtual size_t lastInsert()
const = 0;
// Non-virtual helper that wraps a StatementBase implementation in a Statement handle.
470 Statement makeStatement(
const std::shared_ptr<Detail::StatementBase> &detail);
472 virtual std::string driverName()
const = 0;
// Parameter (class header not in this extract): bookkeeping for one named SQL parameter.
479 friend class ::Sawyer::Database::Detail::StatementBase;
// Positional parameter indexes in the low-level SQL where this name occurs; filled in by
// StatementBase::parseParameters().
481 std::vector<size_t> indexes;
// Whether a value has been bound to this parameter.
482 bool isBound =
false;
// Record one more occurrence of this parameter at positional index idx.
484 void append(
size_t idx) {
485 indexes.push_back(idx);
// ColumnReader<T> (template header not in this extract): reads column idx of a statement's
// current row as Optional<T>. The generic definition reads the column with getString() and
// converts it with boost::lexical_cast. The std::vector<uint8_t> specialization uses getBlob().
491 friend class ::Sawyer::Database::Detail::StatementBase;
492 Optional<T> operator()(StatementBase *stmt,
size_t idx);
// StatementBase: driver-independent core of a prepared statement. Subclasses implement the
// pure-virtual hooks (bindLow, beginLow, nextLow, nColumns, getString, getBlob).
// It tracks named parameters and the statement state (Statement::DEAD, FINISHED, UNBOUND, READY,
// EXECUTING). It also keeps a sequence number that Row and Iterator objects use to detect
// invalidation.
// NOTE(review): this is a Doxygen-listing extract. Gaps in the fused line numbers mark original
// lines (closing braces, access specifiers, some statements) that are missing here.
503class StatementBase:
public std::enable_shared_from_this<StatementBase> {
504 friend class ::Sawyer::Database::Iterator;
505 friend class ::Sawyer::Database::Row;
506 friend class ::Sawyer::Database::Statement;
507 template<
class T>
friend class ::Sawyer::Database::Detail::ColumnReader;
// Parameter name -> positional slots it occupies in the low-level SQL.
509 using Parameters = Container::Map<std::string, Parameter>;
// Strong reference set by lockConnection(); the EXECUTING state asserts it is non-null.
511 std::shared_ptr<ConnectionBase> connection_;
// Non-owning reference; lock() yields null once the connection object is gone.
512 std::weak_ptr<ConnectionBase> weakConnection_;
514 Statement::State state_ = Statement::DEAD;
// Row/Iterator objects record this value and treat any change as invalidation.
515 size_t sequence_ = 0;
516 size_t rowNumber_ = 0;
519 virtual ~StatementBase() {}
// Stores only a weak reference to the connection, which must be non-null.
522 explicit StatementBase(
const std::shared_ptr<ConnectionBase> &connection)
523 : weakConnection_(connection) {
524 ASSERT_not_null(connection);
// Translate SQL with named parameters into low-level SQL with positional parameters. A named
// parameter is written "?name", where the name is alphanumeric characters and '_'.
// Each occurrence is recorded in params_ as the next positional index. Text inside single
// quotes is not scanned for parameters.
// Returns (low-level SQL, number of positional parameters). The state ends as READY, or as
// UNBOUND if any parameter was found.
// Throws Exception for a '?' that is not followed by a name. For mismatched quotes it sets the
// state to DEAD and then throws.
// NOTE(review): the declaration of lowSql and the text emitted for each "?name" are on lines
// missing from this extract.
531 std::pair<std::string, size_t> parseParameters(
const std::string &highSql) {
534 bool inString =
false;
535 size_t nLowParams = 0;
536 state(Statement::READY);
537 for (
size_t i = 0; i < highSql.size(); ++i) {
538 if (
'\'' == highSql[i]) {
539 inString = !inString;
540 lowSql += highSql[i];
541 }
else if (
'?' == highSql[i] && !inString) {
543 std::string paramName;
// Cast to unsigned char: ::isalnum has undefined behavior for negative char values. Those occur
// for non-ASCII bytes (e.g. UTF-8) when char is signed.
544 while (i+1 < highSql.size() && (::isalnum(static_cast<unsigned char>(highSql[i+1])) ||
'_' == highSql[i+1]))
545 paramName += highSql[++i];
546 if (paramName.empty())
547 throw Exception(
"invalid parameter name at character position " + boost::lexical_cast<std::string>(i));
548 Parameter &param = params_.insertMaybeDefault(paramName);
549 param.append(nLowParams++);
550 state(Statement::UNBOUND);
552 lowSql += highSql[i];
556 state(Statement::DEAD);
557 throw Exception(
"mismatched quotes in SQL statement");
559 return std::make_pair(lowSql, nLowParams);
// Invalidate outstanding Row/Iterator objects. NOTE(review): body (lines 564-567) is not in this
// extract; presumably it changes sequence_, which is what Row/Iterator compare against.
563 void invalidateIteratorsAndRows() {
568 size_t sequence()
const {
// Promote the weak connection reference to a strong one; false if the connection is gone.
574 bool lockConnection() {
575 return (connection_ = weakConnection_.lock()) !=
nullptr;
580 void unlockConnection() {
586 bool isConnectionLocked()
const {
587 return connection_ !=
nullptr;
// Connection this statement was prepared on, or null if it no longer exists.
591 std::shared_ptr<ConnectionBase> connection()
const {
592 return weakConnection_.lock();
596 Statement::State state()
const {
// Change state. Entering DEAD, FINISHED, UNBOUND or READY invalidates outstanding rows and
// iterators. Entering EXECUTING requires the connection to be locked.
603 void state(Statement::State newState) {
605 case Statement::DEAD:
606 case Statement::FINISHED:
607 case Statement::UNBOUND:
608 case Statement::READY:
609 invalidateIteratorsAndRows();
612 case Statement::EXECUTING:
613 ASSERT_require(isConnectionLocked());
// Whether any named parameter still lacks a value (loop body not in this extract).
620 bool hasUnboundParameters()
const {
621 ASSERT_forbid(state() == Statement::DEAD);
622 for (
const Parameter &param: params_.values()) {
// Mark every parameter unbound. The state becomes READY if there are no parameters, otherwise
// UNBOUND.
631 virtual void unbindAllParams() {
632 ASSERT_forbid(state() == Statement::DEAD);
633 for (Parameter &param: params_.values())
634 param.isBound = false;
635 state(params_.isEmpty() ? Statement::READY : Statement::UNBOUND);
// Reset execution. Ends in UNBOUND or READY depending on whether unbound parameters remain.
// NOTE(review): the doUnbind handling (lines 643-645) is not in this extract.
640 virtual void reset(
bool doUnbind) {
641 ASSERT_forbid(state() == Statement::DEAD);
642 invalidateIteratorsAndRows();
646 state(hasUnboundParameters() ? Statement::UNBOUND : Statement::READY);
// Bind value to each positional slot of the named parameter (presumably through bindLow()).
// Throws if the connection is closed, the statement is dead, or the name is unknown.
// If binding fails for a parameter occupying more than one slot, the statement is marked DEAD.
// Presumably this is because some slots may already hold the new value.
// Binding the last unbound parameter moves the statement to READY.
// NOTE(review): the FINISHED/EXECUTING handling, the effect of isRebind (lines 661-662) and the
// loop body (lines 670-671) are not in this extract.
653 void bind(
const std::string &name,
const T &value,
bool isRebind) {
655 throw Exception(
"connection is closed");
657 case Statement::DEAD:
658 throw Exception(
"statement is dead");
659 case Statement::FINISHED:
660 case Statement::EXECUTING:
663 case Statement::READY:
664 case Statement::UNBOUND: {
665 if (!params_.exists(name))
666 throw Exception(
"no such parameter \"" + name +
"\" in statement");
667 Parameter &param = params_[name];
668 bool wasUnbound = !param.isBound;
669 for (
size_t idx: param.indexes) {
672 }
catch (
const Exception &e) {
673 if (param.indexes.size() > 1)
674 state(Statement::DEAD);
678 param.isBound =
true;
680 if (wasUnbound && !hasUnboundParameters())
681 state(Statement::READY);
// Fragment of the Optional-valued bind overload (header not in this extract). It binds the
// contained value, or Nothing presumably when the Optional is empty.
691 bind(name, *value, isRebind);
693 bind(name, Nothing(), isRebind);
// Subclass hooks: bind one value to positional slot idx. The Nothing overload binds an absent
// value (presumably SQL NULL).
698 virtual void bindLow(
size_t idx,
int value) = 0;
699 virtual void bindLow(
size_t idx, int64_t value) = 0;
700 virtual void bindLow(
size_t idx,
size_t value) = 0;
701 virtual void bindLow(
size_t idx,
double value) = 0;
702 virtual void bindLow(
size_t idx,
const std::string &value) = 0;
703 virtual void bindLow(
size_t idx,
const char *cstring) = 0;
704 virtual void bindLow(
size_t idx, Nothing) = 0;
705 virtual void bindLow(
size_t idx,
const std::vector<uint8_t> &data) = 0;
// Wrap this statement (via shared_from_this()) in an Iterator.
707 Iterator makeIterator() {
708 return Iterator(shared_from_this());
// Fragment of the function that starts execution (header not in this extract). It rejects closed
// connections and dead statements, and reports every unbound parameter name in one Exception.
// From READY it locks the connection, enters EXECUTING and obtains the first row from beginLow().
715 throw Exception(
"connection is closed");
717 case Statement::DEAD:
718 throw Exception(
"statement is dead");
719 case Statement::UNBOUND: {
721 for (Parameters::Node &param: params_.nodes()) {
722 if (!param.value().isBound)
723 s += (s.empty() ?
"" :
", ") + param.key();
725 ASSERT_forbid(s.empty());
726 throw Exception(
"unbound parameters: " + s);
728 case Statement::FINISHED:
729 case Statement::EXECUTING:
732 case Statement::READY: {
733 if (!lockConnection())
734 throw Exception(
"connection has been closed");
735 state(Statement::EXECUTING);
737 Iterator iter = beginLow();
742 ASSERT_not_reachable(
"invalid state");
747 virtual Iterator beginLow() = 0;
// Fragment of the advance-to-next-row function (header not in this extract). It requires an open
// connection and the EXECUTING state, and invalidates the previous rows/iterators.
752 throw Exception(
"connection is closed");
753 ASSERT_require(state() == Statement::EXECUTING);
754 invalidateIteratorsAndRows();
760 size_t rowNumber()
const {
766 virtual Iterator nextLow() = 0;
// Read column columnIdx of the current row as Optional<T> via ColumnReader<T>.
// Throws if the connection is closed or if columnIdx >= nColumns().
770 Optional<T>
get(
size_t columnIdx) {
772 throw Exception(
"connection is closed");
773 ASSERT_require(state() == Statement::EXECUTING);
774 if (columnIdx >= nColumns())
775 throw Exception(
"column index " + boost::lexical_cast<std::string>(columnIdx) +
" is out of range");
776 return ColumnReader<T>()(
this, columnIdx);
780 virtual size_t nColumns()
const = 0;
// Subclass hooks returning column idx of the current row as text or as raw bytes.
783 virtual Optional<std::string> getString(
size_t idx) = 0;
784 virtual Optional<std::vector<std::uint8_t>> getBlob(
size_t idx) = 0;
// Generic column reader: fetch the column as text, then convert it with boost::lexical_cast<T>.
// lexical_cast throws boost::bad_lexical_cast if the text does not convert.
// NOTE(review): the declaration of str and the no-value branch (lines 790, 792) are not in this
// extract; presumably a missing value yields an empty Optional.
789ColumnReader<T>::operator()(StatementBase *stmt,
size_t idx) {
791 if (!stmt->getString(idx).assignTo(str))
793 return boost::lexical_cast<T>(str);
// Specialization for binary data: read the column directly as a blob instead of via text.
797inline Optional<std::vector<uint8_t>>
798ColumnReader<std::vector<uint8_t>>::operator()(StatementBase *stmt,
size_t idx) {
799 return stmt->getBlob(idx);
// Wrap a statement implementation object in a user-facing Statement handle.
803ConnectionBase::makeStatement(
const std::shared_ptr<Detail::StatementBase> &detail) {
804 return Statement(detail);
// Construct a Connection around an existing implementation object (initializer line not in this
// extract).
814inline Connection::Connection(
const std::shared_ptr<Detail::ConnectionBase> &pimpl)
// A connection is open exactly when it holds an implementation object.
818Connection::isOpen()
const {
819 return pimpl_ !=
nullptr;
// Delegate to the implementation object. NOTE(review): the guard for a null pimpl_ (line 830) and
// the alternative path are not in this extract.
829Connection::driverName()
const {
831 return pimpl_->driverName();
// Prepare a statement through the implementation object. Without an active connection this
// throws Exception("no active database connection"). (The pimpl_ test on line 839 is not in
// this extract.)
838Connection::stmt(
const std::string &sql) {
840 return pimpl_->prepareStatement(sql);
842 throw Exception(
"no active database connection");
// Execute SQL text. NOTE(review): the body (lines 848 onward) is not in this extract.
847Connection::run(
const std::string &sql) {
// Run the query and return column 0 of the first row converted to T. The loop returns on its
// first iteration. The result for a query with no rows is on a line not in this extract.
854Connection::get(
const std::string &sql) {
855 for (
auto row: stmt(sql))
856 return row.
get<T>(0);
// Forward to the implementation's lastInsert(). Throws Exception("no active database connection")
// when there is no implementation object (guard line not in this extract).
861Connection::lastInsert()
const {
863 return pimpl_->lastInsert();
865 throw Exception(
"no active database connection");
// Return the connection this statement belongs to. The implementation locks a weak_ptr, so the
// returned Connection wraps null (isOpen() == false) if that connection has been destroyed.
874Statement::connection()
const {
876 return Connection(pimpl_->connection());
// Bind value to the named parameter by forwarding to the implementation with isRebind=false.
// Throws Exception("no active database connection") when there is no implementation object.
884Statement::bind(
const std::string &name,
const T &value) {
886 pimpl_->bind(name, value,
false);
888 throw Exception(
"no active database connection");
// Same as bind(), but forwards with isRebind=true.
// Throws Exception("no active database connection") when there is no implementation object.
895Statement::rebind(
const std::string &name,
const T &value) {
897 pimpl_->bind(name, value,
true);
899 throw Exception(
"no active database connection");
// Fragment (function header not in this extract; presumably Statement::begin()). It starts
// execution through the implementation object, or throws when there is no active connection.
907 return pimpl_->begin();
909 throw Exception(
"no active database connection");
// Fragment (function header not in this extract): return column 0 of the first result row,
// converted to T. Throws Exception("query did not return a row") when there is no row; the
// end-of-results test on line 928 is not in this extract.
927 Iterator row = begin();
929 throw Exception(
"query did not return a row");
930 return row->get<T>(0);
// Construct an iterator for statement stmt (member-initializer lines not in this extract).
938Iterator::Iterator(
const std::shared_ptr<Detail::StatementBase> &stmt)
// Return the current row. Throws when dereferencing the end iterator. Also throws when the
// statement's sequence number has changed since this iterator's row was created (a stale iterator).
942Iterator::dereference()
const {
944 throw Exception(
"dereferencing the end iterator");
945 if (row_.sequence_ != row_.stmt_->sequence())
946 throw Exception(
"iterator has been invalidated");
// Iterators are equal when they refer to the same statement and recorded the same sequence number.
951Iterator::equal(
const Iterator &other)
const {
952 return row_.stmt_ == other.row_.stmt_ && row_.sequence_ == other.row_.sequence_;
// Advance by replacing *this with the iterator returned by the statement's next(). Throws when
// called on the end iterator; the end test on line 957 is not in this extract.
956Iterator::increment() {
958 throw Exception(
"incrementing the end iterator");
959 *
this = row_.stmt_->next();
// Record the statement and its current sequence number (0 when stmt is null). Later accesses
// compare against this value to detect that the row has been invalidated.
967Row::Row(
const std::shared_ptr<Detail::StatementBase> &stmt)
968 : stmt_(stmt), sequence_(stmt ? stmt->sequence() : 0) {}
// Read column columnIdx as T from the statement. Throws Exception("row has been invalidated")
// if the statement's sequence number changed after this Row was created.
972Row::get(
size_t columnIdx)
const {
973 ASSERT_not_null(stmt_);
974 if (sequence_ != stmt_->sequence())
975 throw Exception(
"row has been invalidated");
976 return stmt_->get<T>(columnIdx);
// Row number as reported by the statement. Throws Exception("row has been invalidated") if the
// statement's sequence number changed after this Row was created.
980Row::rowNumber()
const {
981 ASSERT_not_null(stmt_);
982 if (sequence_ != stmt_->sequence())
983 throw Exception(
"row has been invalidated");
984 return stmt_->rowNumber();
Holds a value or nothing.
bool get(const Word *words, size_t idx)
Return a single bit.
bool increment(Word *vec1, const BitRange &range1)
Increment.