16 Commits
2.0.0 ... 4.2.0

Author SHA1 Message Date
2301da8143 Add PostgreSQL FOR UPDATE getters
Add a PostgreSQL-specific getRecordForUpdate helper that appends FOR
UPDATE to the generated SELECT statement so callers can lock a row
inside an explicit transaction.

generateProcsForModels now always emits a direct-connection
get<RecordName>ForUpdate proc that accepts db_postgres.DbConn. There is
intentionally no dbType overload for this API, because reacquiring a
connection via withConnection would defeat the lock's transactional
scope.

The source docs and README now document the new helper and show the
intended usage pattern inside inTransaction:

  db.inTransaction:
    var item = conn.getTodoItemForUpdate(todoId)
    item.priority += 1
    discard conn.updateTodoItem(item)
2026-03-24 22:04:49 -05:00
71cb5a7cff Update documentation for new signature changes, bump version. 2026-03-24 21:48:51 -05:00
1a9314fe4f Add connection overloads for generated ORM procs
The generated ORM helpers currently only accept the database wrapper
type as their first argument. That works well for the common case, but
it becomes misleading inside inTransaction blocks because the generated
proc will call withConnection again and may therefore use a different
connection than the one that is participating in the transaction.

Add DbConnType-constrained overloads for the generated model CRUD/query
procs, generated lookups, and generated join-table helpers. This lets
callers explicitly use the transaction connection while keeping the
existing dbType-based API intact for non-transactional call sites.

This makes the intended transactional usage straightforward:

  db.inTransaction:
    var userRecord = conn.getUser("userId1")
    userRecord.visitCount += 1
    discard conn.updateUser(userRecord)

AI-Assisted: yes
AI-Tool: OpenAI Codes / gpt-5.4 xhigh
2026-03-24 21:39:25 -05:00
bb36bba864 Support distinct versions of types we know how to convert. 2025-09-02 00:40:00 -05:00
f54bf6e974 Add tryGet<RecordName> versions of get<Record> calls
`tryGet<RecordName>`  returns Option types rather than raise exceptions.

For example:

    generateProcsForModels(MyDb, [ User ])

will now create both:

    proc getUser*(db: MyDb, id: string): User
    proc tryGetUser*(db: MyDb, id: string): Option[User]
2025-09-02 00:36:11 -05:00
e1fa2480d0 Major update to provide thread-safe, robust connection pooling.
Taking inspiration from the waterpark library, the connection pooling
mechanism has been refactored to be thread-safe. Additionally, the
pooling logic detects and handles stale connections in the pool. When a
connection is requested from the pool, the pool first validates that it
is healthy and replaces it with a fresh connection if necessary. This is
transparent to the requester.

Additionally we refactored the internal logging implementation to make
it more conventient to access logging infrastructure and log from
various sub-scopes within fiber_orm (query, pool, etc.)
2025-07-27 17:47:07 -05:00
b8c64cc693 Migrate to namespaced_logging v2. 2025-07-12 07:54:13 -05:00
aa02f9f5b1 Add support for records associated via join tables. 2025-05-19 17:56:40 -05:00
9d1cc4bbec Cache logger instance. 2025-01-20 06:39:02 -06:00
af44d48df1 Extract pagination logic into a common, exported function. Fix PG date parsing (again). 2025-01-10 20:25:49 -06:00
2030fd4490 Use namespaced_logging 1.x for logging (optionally). 2025-01-05 02:06:57 -06:00
0599d41061 Support Nim 2.x, compatibility with waterpark.
- Nim 2.x has moved the DB connectors outside the standard library to
  the `db_connector` package.
- Refactor the pooling implementation and macro expectations to use the
  `withConnection` name instead of `withConn`. This change allows a
  caller to use a [waterpark](https://github.com/guzba/waterpark) pool
  instance instead of the builtin pool instance. Waterpark provides
  better support for multi-threaded environments. The builtin pooling
  mechanism may be deprecated in favor of waterpark in the future.
- Add the `getModelIfItExists` generated proc to the list of standard
  procs we generate. This is a flavour of `getModel` that returns an
  `Option` instead of raising an exception when there is no model for
  the given id.
- Change `PaginationParams#orderBy` to accept a `seq[string]` to allow
  for sorting on multiple fields.
2025-01-03 07:55:05 -06:00
fb74d84cb7 Map names to db ident names for columns passed for ordering in paginated queries. 2023-08-09 09:16:10 -05:00
fbd20de71f Add createOrUpdateRecord and record method generators.
`createOrUpdateRecord` implements upsert: update an existing record if
it exists or create a new record if not. A new error `DbUpdateError` was
added to be raised when an existing record does exist but was not able
to be updated.
2023-08-09 09:13:12 -05:00
540d0d2f67 Fix missing import in pooling implementation. 2023-02-04 19:04:50 -06:00
a05555ee67 WIP - Initial stab at making it generic to support db_sqlite. 2022-11-03 16:38:14 -05:00
8 changed files with 888 additions and 319 deletions

2
.gitignore vendored
View File

@@ -1,2 +1,4 @@
*.sw?
nimcache/
nimble.develop
nimble.paths

View File

@@ -57,7 +57,7 @@ Models may be defined as:
.. code-block:: Nim
# models.nim
import std/options, std/times
import std/[options, times]
import uuids
type
@@ -82,6 +82,8 @@ Using Fiber ORM we can generate a data access layer with:
.. code-block:: Nim
# db.nim
import std/[options]
import db_connector/db_postgres
import fiber_orm
import ./models.nim
@@ -98,30 +100,90 @@ Using Fiber ORM we can generate a data access layer with:
generateLookup(TodoDB, TimeEntry, @["todoItemId"])
This will generate the following procedures:
This will generate procedures like the following in two flavors:
* a `dbType` flavor that acquires a connection via `withConnection`
* a connection flavor that operates directly on an existing
`conn: D` where `D: DbConnType`
.. code-block:: Nim
proc getTodoItem*(db: TodoDB, id: UUID): TodoItem;
proc getAllTodoItems*(db: TodoDB): seq[TodoItem];
proc getTodoItem*[D: DbConnType](conn: D, id: UUID): TodoItem;
proc getTodoItemForUpdate*(conn: db_postgres.DbConn, id: UUID): TodoItem;
proc tryGetTodoItem*(db: TodoDB, id: UUID): Option[TodoItem];
proc tryGetTodoItem*[D: DbConnType](conn: D, id: UUID): Option[TodoItem];
proc getTodoItemIfItExists*(db: TodoDB, id: UUID): Option[TodoItem];
proc getTodoItemIfItExists*[D: DbConnType](
conn: D, id: UUID): Option[TodoItem];
proc getAllTodoItems*(db: TodoDB,
pagination = none[PaginationParams]()): PagedRecords[TodoItem];
proc getAllTodoItems*[D: DbConnType](conn: D,
pagination = none[PaginationParams]()): PagedRecords[TodoItem];
proc createTodoItem*(db: TodoDB, rec: TodoItem): TodoItem;
proc createTodoItem*[D: DbConnType](conn: D, rec: TodoItem): TodoItem;
proc updateTodoItem*(db: TodoDB, rec: TodoItem): bool;
proc updateTodoItem*[D: DbConnType](conn: D, rec: TodoItem): bool;
proc createOrUpdateTodoItem*(db: TodoDB, rec: TodoItem): TodoItem;
proc createOrUpdateTodoItem*[D: DbConnType](
conn: D, rec: TodoItem): TodoItem;
proc deleteTodoItem*(db: TodoDB, rec: TodoItem): bool;
proc deleteTodoItem*[D: DbConnType](conn: D, rec: TodoItem): bool;
proc deleteTodoItem*(db: TodoDB, id: UUID): bool;
proc deleteTodoItem*[D: DbConnType](conn: D, id: UUID): bool;
proc findTodoItemsWhere*(db: TodoDB, whereClause: string,
values: varargs[string, dbFormat]): seq[TodoItem];
values: varargs[string, dbFormat],
pagination = none[PaginationParams]()): PagedRecords[TodoItem];
proc findTodoItemsWhere*[D: DbConnType](conn: D, whereClause: string,
values: varargs[string, dbFormat],
pagination = none[PaginationParams]()): PagedRecords[TodoItem];
proc getTimeEntry*(db: TodoDB, id: UUID): TimeEntry;
proc getAllTimeEntries*(db: TodoDB): seq[TimeEntry];
proc getTimeEntry*[D: DbConnType](conn: D, id: UUID): TimeEntry;
proc getTimeEntryIfItExists*(db: TodoDB, id: UUID): Option[TimeEntry];
proc getTimeEntryIfItExists*[D: DbConnType](
conn: D, id: UUID): Option[TimeEntry];
proc getAllTimeEntries*(db: TodoDB,
pagination = none[PaginationParams]()): PagedRecords[TimeEntry];
proc getAllTimeEntries*[D: DbConnType](conn: D,
pagination = none[PaginationParams]()): PagedRecords[TimeEntry];
proc createTimeEntry*(db: TodoDB, rec: TimeEntry): TimeEntry;
proc createTimeEntry*[D: DbConnType](conn: D, rec: TimeEntry): TimeEntry;
proc updateTimeEntry*(db: TodoDB, rec: TimeEntry): bool;
proc updateTimeEntry*[D: DbConnType](conn: D, rec: TimeEntry): bool;
proc deleteTimeEntry*(db: TodoDB, rec: TimeEntry): bool;
proc deleteTimeEntry*[D: DbConnType](conn: D, rec: TimeEntry): bool;
proc deleteTimeEntry*(db: TodoDB, id: UUID): bool;
proc deleteTimeEntry*[D: DbConnType](conn: D, id: UUID): bool;
proc findTimeEntriesWhere*(db: TodoDB, whereClause: string,
values: varargs[string, dbFormat]): seq[TimeEntry];
values: varargs[string, dbFormat],
pagination = none[PaginationParams]()): PagedRecords[TimeEntry];
proc findTimeEntriesWhere*[D: DbConnType](conn: D, whereClause: string,
values: varargs[string, dbFormat],
pagination = none[PaginationParams]()): PagedRecords[TimeEntry];
proc findTimeEntriesByTodoItemId(db: TodoDB, todoItemId: UUID): seq[TimeEntry];
proc findTimeEntriesByTodoItemId*(db: TodoDB, todoItemId: UUID,
pagination = none[PaginationParams]()): PagedRecords[TimeEntry];
proc findTimeEntriesByTodoItemId*[D: DbConnType](
conn: D, todoItemId: UUID,
pagination = none[PaginationParams]()): PagedRecords[TimeEntry];
Use the `dbType` flavor when the caller does not already have a connection.
Use the connection flavor inside `withConnection` or `inTransaction`.
The generated `get<RecordName>ForUpdate` helper is PostgreSQL-specific and
is only available for direct PostgreSQL connections.
Warning: do not call the `dbType` flavor from inside `inTransaction`.
Those overloads call `withConnection` and may acquire a different
connection, causing the statements to execute outside the active
transaction.
.. code-block:: Nim
db.inTransaction:
var item = conn.getTodoItemForUpdate(todoId)
item.priority += 1
discard conn.updateTodoItem(item)
Object-Relational Modeling
==========================
@@ -129,11 +191,11 @@ Object-Relational Modeling
Model Class
-----------
Fiber ORM uses simple Nim `object`s and `ref object`s as model classes.
Fiber ORM uses simple Nim objects and ref objects as model classes.
Fiber ORM expects there to be one table for each model class.
Name Mapping
````````````
^^^^^^^^^^^^
Fiber ORM uses `snake_case` for database identifiers (column names, table
names, etc.) and `camelCase` for Nim identifiers. We automatically convert
model names to and from table names (`TodoItem` <-> `todo_items`), as well
@@ -164,7 +226,7 @@ procedures in the `fiber_orm/util`_ module for details.
.. _util: fiber_orm/util.html
ID Field
````````
^^^^^^^^
Fiber ORM expects every model class to have a field named `id`, with a
corresponding `id` column in the model table. This field must be either a
@@ -253,8 +315,10 @@ Many of the Fiber ORM macros expect a database object type to be passed.
In the example above the `pool.DbConnPool`_ object is used as database
object type (aliased as `TodoDB`). This is the intended usage pattern, but
anything can be passed as the database object type so long as there is a
defined `withConn` template that provides an injected `conn: DbConn` object
defined `withConnection` template that provides an injected `conn: DbConn` object
to the provided statement body.
The generated connection-flavor procedures are intended to work directly
with that `conn` value.
For example, a valid database object implementation that opens a new
connection for every request might look like this:
@@ -265,7 +329,7 @@ connection for every request might look like this:
type TodoDB* = object
connString: string
template withConn*(db: TodoDB, stmt: untyped): untyped =
template withConnection*(db: TodoDB, stmt: untyped): untyped =
let conn {.inject.} = open("", "", "", db.connString)
try: stmt
finally: close(conn)

View File

@@ -1,6 +1,6 @@
# Package
version = "2.0.0"
version = "4.2.0"
author = "Jonathan Bernard"
description = "Lightweight Postgres ORM for Nim."
license = "GPL-3.0"
@@ -11,4 +11,4 @@ srcDir = "src"
# Dependencies
requires @["nim >= 1.4.0", "uuids"]
requires "https://git.jdb-software.com/jdb/nim-namespaced-logging.git"
requires "namespaced_logging >= 2.0.2"

View File

@@ -1,6 +1,6 @@
# Fiber ORM
#
# Copyright 2019 Jonathan Bernard <jonathan@jdbernard.com>
# Copyright 2019-2024 Jonathan Bernard <jonathan@jdbernard.com>
## Lightweight ORM supporting the `Postgres`_ and `SQLite`_ databases in Nim.
## It supports a simple, opinionated model mapper to generate SQL queries based
@@ -100,38 +100,89 @@
##
## generateLookup(TodoDB, TimeEntry, @["todoItemId"])
##
## This will generate the following procedures:
## This will generate procedures like the following in two flavors:
##
## * a `dbType` flavor that acquires a connection via `withConnection`
## * a connection flavor that operates directly on an existing
## `conn: D` where `D: DbConnType`
##
## .. code-block:: Nim
## proc getTodoItem*(db: TodoDB, id: UUID): TodoItem;
##
## proc getTodoItem*[D: DbConnType](conn: D, id: UUID): TodoItem;
## proc getTodoItemForUpdate*(conn: db_postgres.DbConn, id: UUID): TodoItem;
## proc tryGetTodoItem*(db: TodoDB, id: UUID): Option[TodoItem];
## proc tryGetTodoItem*[D: DbConnType](conn: D, id: UUID): Option[TodoItem];
## proc getTodoItemIfItExists*(db: TodoDB, id: UUID): Option[TodoItem];
## proc getTodoItemIfItExists*[D: DbConnType](
## conn: D, id: UUID): Option[TodoItem];
## proc createTodoItem*(db: TodoDB, rec: TodoItem): TodoItem;
## proc createTodoItem*[D: DbConnType](conn: D, rec: TodoItem): TodoItem;
## proc updateTodoItem*(db: TodoDB, rec: TodoItem): bool;
## proc updateTodoItem*[D: DbConnType](conn: D, rec: TodoItem): bool;
## proc createOrUpdateTodoItem*(db: TodoDB, rec: TodoItem): TodoItem;
## proc createOrUpdateTodoItem*[D: DbConnType](
## conn: D, rec: TodoItem): TodoItem;
## proc deleteTodoItem*(db: TodoDB, rec: TodoItem): bool;
## proc deleteTodoItem*[D: DbConnType](conn: D, rec: TodoItem): bool;
## proc deleteTodoItem*(db: TodoDB, id: UUID): bool;
## proc deleteTodoItem*[D: DbConnType](conn: D, id: UUID): bool;
##
## proc getAllTodoItems*(db: TodoDB,
## pagination = none[PaginationParams]()): seq[TodoItem];
## pagination = none[PaginationParams]()): PagedRecords[TodoItem];
## proc getAllTodoItems*[D: DbConnType](conn: D,
## pagination = none[PaginationParams]()): PagedRecords[TodoItem];
##
## proc findTodoItemsWhere*(db: TodoDB, whereClause: string,
## values: varargs[string, dbFormat], pagination = none[PaginationParams]()
## ): seq[TodoItem];
## ): PagedRecords[TodoItem];
## proc findTodoItemsWhere*[D: DbConnType](conn: D, whereClause: string,
## values: varargs[string, dbFormat], pagination = none[PaginationParams]()
## ): PagedRecords[TodoItem];
##
## proc getTimeEntry*(db: TodoDB, id: UUID): TimeEntry;
## proc getTimeEntry*[D: DbConnType](conn: D, id: UUID): TimeEntry;
## proc createTimeEntry*(db: TodoDB, rec: TimeEntry): TimeEntry;
## proc createTimeEntry*[D: DbConnType](conn: D, rec: TimeEntry): TimeEntry;
## proc updateTimeEntry*(db: TodoDB, rec: TimeEntry): bool;
## proc updateTimeEntry*[D: DbConnType](conn: D, rec: TimeEntry): bool;
## proc deleteTimeEntry*(db: TodoDB, rec: TimeEntry): bool;
## proc deleteTimeEntry*[D: DbConnType](conn: D, rec: TimeEntry): bool;
## proc deleteTimeEntry*(db: TodoDB, id: UUID): bool;
## proc deleteTimeEntry*[D: DbConnType](conn: D, id: UUID): bool;
##
## proc getAllTimeEntries*(db: TodoDB,
## pagination = none[PaginationParams]()): seq[TimeEntry];
## pagination = none[PaginationParams]()): PagedRecords[TimeEntry];
## proc getAllTimeEntries*[D: DbConnType](conn: D,
## pagination = none[PaginationParams]()): PagedRecords[TimeEntry];
##
## proc findTimeEntriesWhere*(db: TodoDB, whereClause: string,
## values: varargs[string, dbFormat], pagination = none[PaginationParams]()
## ): seq[TimeEntry];
## ): PagedRecords[TimeEntry];
## proc findTimeEntriesWhere*[D: DbConnType](conn: D, whereClause: string,
## values: varargs[string, dbFormat], pagination = none[PaginationParams]()
## ): PagedRecords[TimeEntry];
##
## proc findTimeEntriesByTodoItemId(db: TodoDB, todoItemId: UUID,
## pagination = none[PaginationParams]()): seq[TimeEntry];
## proc findTimeEntriesByTodoItemId*(db: TodoDB, todoItemId: UUID,
## pagination = none[PaginationParams]()): PagedRecords[TimeEntry];
## proc findTimeEntriesByTodoItemId*[D: DbConnType](
## conn: D, todoItemId: UUID,
## pagination = none[PaginationParams]()): PagedRecords[TimeEntry];
##
## Use the `dbType` flavor when the caller does not already have a connection.
## Use the connection flavor inside `withConnection` or `inTransaction`.
## The generated `get<RecordName>ForUpdate` helper is PostgreSQL-specific and
## is only available for direct PostgreSQL connections.
##
## Warning: do not call the `dbType` flavor from inside `inTransaction`.
## Those overloads call `withConnection` and may acquire a different
## connection, causing the statements to execute outside the active
## transaction.
##
## .. code-block:: Nim
## db.inTransaction:
## var item = conn.getTodoItemForUpdate(todoId)
## item.priority += 1
## discard conn.updateTodoItem(item)
##
## Object-Relational Modeling
## ==========================
@@ -139,11 +190,11 @@
## Model Class
## -----------
##
## Fiber ORM uses simple Nim `object`s and `ref object`s as model classes.
## Fiber ORM uses simple Nim objects and ref objects as model classes.
## Fiber ORM expects there to be one table for each model class.
##
## Name Mapping
## ````````````
## ^^^^^^^^^^^^
## Fiber ORM uses `snake_case` for database identifiers (column names, table
## names, etc.) and `camelCase` for Nim identifiers. We automatically convert
## model names to and from table names (`TodoItem` <-> `todo_items`), as well
@@ -174,7 +225,7 @@
## .. _util: fiber_orm/util.html
##
## ID Field
## ````````
## ^^^^^^^^
##
## Fiber ORM expects every model class to have a field named `id`, with a
## corresponding `id` column in the model table. This field must be either a
@@ -263,70 +314,60 @@
## In the example above the `pool.DbConnPool`_ object is used as database
## object type (aliased as `TodoDB`). This is the intended usage pattern, but
## anything can be passed as the database object type so long as there is a
## defined `withConn` template that provides an injected `conn: DbConn` object
## defined `withConnection` template that provides a `conn: DbConn` object
## to the provided statement body.
## The generated connection-flavor procedures are intended to work directly
## with that `conn` value.
##
## For example, a valid database object implementation that opens a new
## connection for every request might look like this:
##
## .. code-block:: Nim
## import std/db_postgres
## import db_connector/db_postgres
##
## type TodoDB* = object
## connString: string
##
## template withConn*(db: TodoDB, stmt: untyped): untyped =
## let conn {.inject.} = open("", "", "", db.connString)
## template withConnection*(db: TodoDB, stmt: untyped): untyped =
## block:
## let conn = open("", "", "", db.connString)
## try: stmt
## finally: close(conn)
##
## .. _pool.DbConnPool: fiber_orm/pool.html#DbConnPool
##
import std/db_postgres, std/macros, std/options, std/sequtils, std/strutils
import namespaced_logging, uuids
import std/[json, macros, options, sequtils, strutils]
import db_connector/[db_common, db_postgres]
import uuids
from std/unicode import capitalize
import ./fiber_orm/pool
import ./fiber_orm/util
import ./fiber_orm/db_common as fiber_db_common
import ./fiber_orm/[pool, util]
import ./fiber_orm/private/logging
export
pool,
util.columnNamesForModel,
util.dbFormat,
util.dbNameToIdent,
util.identNameToDb,
util.modelName,
util.rowToModel,
util.tableName
export pool, util
export logging.enableDbLogging
type
PaginationParams* = object
pageSize*: int
offset*: int
orderBy*: Option[string]
PagedRecords*[T] = object
pagination*: Option[PaginationParams]
records*: seq[T]
totalRecords*: int
NotFoundError* = object of CatchableError ##\
DbUpdateError* = object of CatchableError
## Error types raised when a DB modification fails.
NotFoundError* = object of CatchableError
## Error type raised when no record matches a given ID
var logNs {.threadvar.}: LoggingNamespace
template log(): untyped =
if logNs.isNil: logNs = initLoggingNamespace(name = "fiber_orm", level = lvlNotice)
logNs
proc newMutateClauses(): MutateClauses =
return MutateClauses(
columns: @[],
placeholders: @[],
values: @[])
proc createRecord*[T](db: DbConn, rec: T): T =
proc createRecord*[D: DbConnType, T](db: D, rec: T): T =
## Create a new record. `rec` is expected to be a `model class`_. The `id`
## field is only set if it is non-empty (see `ID Field`_ for details).
##
@@ -344,12 +385,13 @@ proc createRecord*[T](db: DbConn, rec: T): T =
" VALUES (" & mc.placeholders.join(",") & ") " &
" RETURNING " & columnNamesForModel(rec).join(",")
log().debug "createRecord: [" & sqlStmt & "]"
logQuery("createRecord", sqlStmt)
let newRow = db.getRow(sql(sqlStmt), mc.values)
result = rowToModel(T, newRow)
proc updateRecord*[T](db: DbConn, rec: T): bool =
proc updateRecord*[D: DbConnType, T](db: D, rec: T): bool =
## Update a record by id. `rec` is expected to be a `model class`_.
var mc = newMutateClauses()
populateMutateClauses(rec, false, mc)
@@ -360,33 +402,51 @@ proc updateRecord*[T](db: DbConn, rec: T): bool =
" SET " & setClause &
" WHERE id = ? "
log().debug "updateRecord: [" & sqlStmt & "] id: " & $rec.id
logQuery("updateRecord", sqlStmt, [("id", $rec.id)])
let numRowsUpdated = db.execAffectedRows(sql(sqlStmt), mc.values.concat(@[$rec.id]))
return numRowsUpdated > 0;
template deleteRecord*(db: DbConn, modelType: type, id: typed): untyped =
proc createOrUpdateRecord*[D: DbConnType, T](db: D, rec: T): T =
## Create or update a record. `rec` is expected to be a `model class`_. If
## the `id` field is unset, or if there is no existing record with the given
## id, a new record is inserted. Otherwise, the existing record is updated.
##
## Note that this does not perform partial updates, all fields are updated.
let findRecordStmt = "SELECT id FROM " & tableName(rec) & " WHERE id = ?"
logQuery("createOrUpdateRecord", findRecordStmt, [("id", $rec.id)])
let rows = db.getAllRows(sql(findRecordStmt), [$rec.id])
if rows.len == 0: result = createRecord(db, rec)
else:
result = rec
if not updateRecord(db, rec):
raise newException(DbUpdateError,
"unable to update " & modelName(rec) & " for id " & $rec.id)
template deleteRecord*[D: DbConnType](db: D, modelType: type, id: typed): untyped =
## Delete a record by id.
let sqlStmt = "DELETE FROM " & tableName(modelType) & " WHERE id = ?"
log().debug "deleteRecord: [" & sqlStmt & "] id: " & $id
logQuery("deleteRecord", sqlStmt, [("id", $id)])
db.tryExec(sql(sqlStmt), $id)
proc deleteRecord*[T](db: DbConn, rec: T): bool =
proc deleteRecord*[D: DbConnType, T](db: D, rec: T): bool =
## Delete a record by `id`_.
##
## .. _id: #model-class-id-field
let sqlStmt = "DELETE FROM " & tableName(rec) & " WHERE id = ?"
log().debug "deleteRecord: [" & sqlStmt & "] id: " & $rec.id
logQuery("deleteRecord", sqlStmt, [("id", $rec.id)])
return db.tryExec(sql(sqlStmt), $rec.id)
template getRecord*(db: DbConn, modelType: type, id: typed): untyped =
template getRecord*[D: DbConnType](db: D, modelType: type, id: typed): untyped =
## Fetch a record by id.
let sqlStmt =
"SELECT " & columnNamesForModel(modelType).join(",") &
" FROM " & tableName(modelType) &
" WHERE id = ?"
log().debug "getRecord: [" & sqlStmt & "] id: " & $id
logQuery("getRecord", sqlStmt, [("id", $id)])
let row = db.getRow(sql(sqlStmt), @[$id])
if allIt(row, it.len == 0):
@@ -394,8 +454,38 @@ template getRecord*(db: DbConn, modelType: type, id: typed): untyped =
rowToModel(modelType, row)
template findRecordsWhere*(
db: DbConn,
template getRecordForUpdate*(db: db_postgres.DbConn, modelType: type, id: typed): untyped =
## Fetch a record by id and lock it with `FOR UPDATE`.
##
## This is PostgreSQL-specific and should only be used inside a transaction.
let sqlStmt =
"SELECT " & columnNamesForModel(modelType).join(",") &
" FROM " & tableName(modelType) &
" WHERE id = ? FOR UPDATE"
logQuery("getRecordForUpdate", sqlStmt, [("id", $id)])
let row = db.getRow(sql(sqlStmt), @[$id])
if allIt(row, it.len == 0):
raise newException(NotFoundError, "no " & modelName(modelType) & " record for id " & $id)
rowToModel(modelType, row)
template tryGetRecord*[D: DbConnType](db: D, modelType: type, id: typed): untyped =
## Fetch a record by id.
let sqlStmt =
"SELECT " & columnNamesForModel(modelType).join(",") &
" FROM " & tableName(modelType) &
" WHERE id = ?"
logQuery("tryGetRecord", sqlStmt, [("id", $id)])
let row = db.getRow(sql(sqlStmt), @[$id])
if allIt(row, it.len == 0): none[modelType]()
else: some(rowToModel(modelType, row))
template findRecordsWhere*[D: DbConnType](
db: D,
modelType: type,
whereClause: string,
values: varargs[string, dbFormat],
@@ -412,17 +502,9 @@ template findRecordsWhere*(
"SELECT COUNT(*) FROM " & tableName(modelType) &
" WHERE " & whereClause
if page.isSome:
let p = page.get
if p.orderBy.isSome:
fetchStmt &= " ORDER BY " & p.orderBy.get
else:
fetchStmt &= " ORDER BY id"
if page.isSome: fetchStmt &= getPagingClause(page.get)
fetchStmt &= " LIMIT " & $p.pageSize &
" OFFSET " & $p.offset
log().debug "findRecordsWhere: [" & fetchStmt & "] values: (" & values.join(", ") & ")"
logQuery("findRecordsWhere", fetchStmt, [("values", values.join(", "))])
let records = db.getAllRows(sql(fetchStmt), values).mapIt(rowToModel(modelType, it))
PagedRecords[modelType](
@@ -432,8 +514,8 @@ template findRecordsWhere*(
if page.isNone: records.len
else: db.getRow(sql(countStmt), values)[0].parseInt)
template getAllRecords*(
db: DbConn,
template getAllRecords*[D: DbConnType](
db: D,
modelType: type,
page: Option[PaginationParams]): untyped =
## Fetch all records of the given type.
@@ -443,17 +525,9 @@ template getAllRecords*(
var countStmt = "SELECT COUNT(*) FROM " & tableName(modelType)
if page.isSome:
let p = page.get
if p.orderBy.isSome:
fetchStmt &= " ORDER BY " & p.orderBy.get
else:
fetchStmt &= " ORDER BY id"
if page.isSome: fetchStmt &= getPagingClause(page.get)
fetchStmt &= " LIMIT " & $p.pageSize &
" OFFSET " & $p.offset
log().debug "getAllRecords: [" & fetchStmt & "]"
logQuery("getAllRecords", fetchStmt)
let records = db.getAllRows(sql(fetchStmt)).mapIt(rowToModel(modelType, it))
PagedRecords[modelType](
@@ -464,8 +538,8 @@ template getAllRecords*(
else: db.getRow(sql(countStmt))[0].parseInt)
template findRecordsBy*(
db: DbConn,
template findRecordsBy*[D: DbConnType](
db: D,
modelType: type,
lookups: seq[tuple[field: string, value: string]],
page: Option[PaginationParams]): untyped =
@@ -482,17 +556,9 @@ template findRecordsBy*(
"SELECT COUNT(*) FROM " & tableName(modelType) &
" WHERE " & whereClause
if page.isSome:
let p = page.get
if p.orderBy.isSome:
fetchStmt &= " ORDER BY " & p.orderBy.get
else:
fetchStmt &= " ORDER BY id"
if page.isSome: fetchStmt &= getPagingClause(page.get)
fetchStmt &= " LIMIT " & $p.pageSize &
" OFFSET " & $p.offset
log().debug "findRecordsBy: [" & fetchStmt & "] values (" & values.join(", ") & ")"
logQuery("findRecordsBy", fetchStmt, [("values", values.join(", "))])
let records = db.getAllRows(sql(fetchStmt), values).mapIt(rowToModel(modelType, it))
PagedRecords[modelType](
@@ -503,63 +569,235 @@ template findRecordsBy*(
else: db.getRow(sql(countStmt), values)[0].parseInt)
template associate*[D: DbConnType, I, J](
db: D,
joinTableName: string,
rec1: I,
rec2: J): void =
## Associate two records via a join table.
let insertStmt =
"INSERT INTO " & joinTableName &
" (" & tableName(I) & "_id, " & tableName(J) & "_id) " &
" VALUES (?, ?)"
logQuery("associate", insertStmt, [("id1", $rec1.id), ("id2", $rec2.id)])
db.exec(sql(insertStmt), [$rec1.id, $rec2.id])
template findViaJoinTable*[D: DbConnType, L](
db: D,
joinTableName: string,
targetType: type,
rec: L,
page: Option[PaginationParams]): untyped =
## Find all records of `targetType` that are associated with `rec` via a
## join table.
let columns = columnNamesForModel(targetType).mapIt("t." & it).join(",")
var fetchStmt =
"SELECT " & columns &
" FROM " & tableName(targetType) & " AS t " &
" JOIN " & joinTableName & " AS j " &
" ON t.id = jt." & tableName(targetType) & "_id " &
" WHERE jt." & tableName(rec) & "_id = ?"
var countStmt =
"SELECT COUNT(*) FROM " & joinTableName &
" WHERE " & tableName(rec) & "_id = ?"
if page.isSome: fetchStmt &= getPagingClause(page.get)
logQuery("findViaJoinTable", fetchStmt, [("id", $rec.id)])
let records = db.getAllRows(sql(fetchStmt), $rec.id)
.mapIt(rowToModel(targetType, it))
PagedRecords[targetType](
pagination: page,
records: records,
totalRecords:
if page.isNone: records.len
else: db.getRow(sql(countStmt))[0].parseInt)
template findViaJoinTable*[D: DbConnType](
db: D,
joinTableName: string,
targetType: type,
lookupType: type,
id: typed,
page: Option[PaginationParams]): untyped =
## Find all records of `targetType` that are associated with a record of
## `lookupType` via a join table.
let columns = columnNamesForModel(targetType).mapIt("t." & it).join(",")
var fetchStmt =
"SELECT " & columns &
" FROM " & tableName(targetType) & " AS t " &
" JOIN " & joinTableName & " AS j " &
" ON t.id = jt." & tableName(targetType) & "_id " &
" WHERE jt." & tableName(lookupType) & "_id = ?"
var countStmt =
"SELECT COUNT(*) FROM " & joinTableName &
" WHERE " & tableName(lookupType) & "_id = ?"
if page.isSome: fetchStmt &= getPagingClause(page.get)
logQuery("findViaJoinTable", fetchStmt, [("id", $id)])
let records = db.getAllRows(sql(fetchStmt), $id)
.mapIt(rowToModel(targetType, it))
PagedRecords[targetType](
pagination: page,
records: records,
totalRecords:
if page.isNone: records.len
else: db.getRow(sql(countStmt))[0].parseInt)
macro generateProcsForModels*(dbType: type, modelTypes: openarray[type]): untyped =
## Generate all standard access procedures for the given model types. For a
## `model class`_ named `TodoItem`, this will generate the following
## procedures:
## `model class`_ named `TodoItem`, this will generate `dbType` and
## connection overloads for procedures like the following:
##
## .. code-block:: Nim
## proc getTodoItem*(db: TodoDB, id: idType): TodoItem;
## proc getAllTodoItems*(db: TodoDB): TodoItem;
## proc getTodoItem*[D: DbConnType](conn: D, id: idType): TodoItem;
## proc getTodoItemForUpdate*(conn: db_postgres.DbConn, id: idType): TodoItem;
## proc tryGetTodoItem*(db: TodoDB, id: idType): Option[TodoItem];
## proc tryGetTodoItem*[D: DbConnType](conn: D, id: idType): Option[TodoItem];
## proc getTodoItemIfItExists*(db: TodoDB, id: idType): Option[TodoItem];
## proc getTodoItemIfItExists*[D: DbConnType](
## conn: D, id: idType): Option[TodoItem];
## proc getAllTodoItems*(db: TodoDB): PagedRecords[TodoItem];
## proc getAllTodoItems*[D: DbConnType](conn: D): PagedRecords[TodoItem];
## proc createTodoItem*(db: TodoDB, rec: TodoItem): TodoItem;
## proc createTodoItem*[D: DbConnType](conn: D, rec: TodoItem): TodoItem;
## proc deleteTodoItem*(db: TodoDB, rec: TodoItem): bool;
## proc deleteTodoItem*[D: DbConnType](conn: D, rec: TodoItem): bool;
## proc deleteTodoItem*(db: TodoDB, id: idType): bool;
## proc deleteTodoItem*[D: DbConnType](conn: D, id: idType): bool;
## proc updateTodoItem*(db: TodoDB, rec: TodoItem): bool;
## proc updateTodoItem*[D: DbConnType](conn: D, rec: TodoItem): bool;
## proc createOrUpdateTodoItem*(db: TodoDB, rec: TodoItem): TodoItem;
## proc createOrUpdateTodoItem*[D: DbConnType](
## conn: D, rec: TodoItem): TodoItem;
##
## proc findTodoItemsWhere*(
## db: TodoDB, whereClause: string, values: varargs[string]): TodoItem;
## db: TodoDB,
## whereClause: string,
## values: varargs[string, dbFormat],
## pagination = none[PaginationParams]()): PagedRecords[TodoItem];
## proc findTodoItemsWhere*[D: DbConnType](
## conn: D,
## whereClause: string,
## values: varargs[string, dbFormat],
## pagination = none[PaginationParams]()): PagedRecords[TodoItem];
##
## `dbType` is expected to be some type that has a defined `withConn`
## `dbType` is expected to be some type that has a defined `withConnection`
## procedure (see `Database Object`_ for details).
## The `dbType` overloads are convenience wrappers around `withConnection`.
## Inside `inTransaction`, prefer the overloads that take `conn: D` where
## `D: DbConnType` so all operations use the transaction connection.
## The generated `get<RecordName>ForUpdate` helper is PostgreSQL-specific and
## is only available for direct PostgreSQL connections.
##
## .. _Database Object: #database-object
result = newStmtList()
for t in modelTypes:
if t.getType[1].typeKind == ntyRef:
raise newException(ValueError,
"fiber_orm model object must be objects, not refs")
let modelName = $(t.getType[1])
let getName = ident("get" & modelName)
let getForUpdateName = ident("get" & modelName & "ForUpdate")
let tryGetName = ident("tryGet" & modelName)
let getIfExistsName = ident("get" & modelName & "IfItExists")
let getAllName = ident("getAll" & pluralize(modelName))
let findWhereName = ident("find" & pluralize(modelName) & "Where")
let createName = ident("create" & modelName)
let updateName = ident("update" & modelName)
let createOrUpdateName = ident("createOrUpdate" & modelName)
let deleteName = ident("delete" & modelName)
let idType = typeOfColumn(t, "id")
result.add quote do:
proc `getName`*(db: `dbType`, id: `idType`): `t` =
db.withConn: result = getRecord(conn, `t`, id)
db.withConnection conn: result = getRecord(conn, `t`, id)
proc `getName`*[D: DbConnType](conn: D, id: `idType`): `t` =
result = getRecord(conn, `t`, id)
proc `getForUpdateName`*(conn: db_postgres.DbConn, id: `idType`): `t` =
result = getRecordForUpdate(conn, `t`, id)
proc `tryGetName`*(db: `dbType`, id: `idType`): Option[`t`] =
db.withConnection conn: result = tryGetRecord(conn, `t`, id)
proc `tryGetName`*[D: DbConnType](conn: D, id: `idType`): Option[`t`] =
result = tryGetRecord(conn, `t`, id)
proc `getIfExistsName`*(db: `dbType`, id: `idType`): Option[`t`] =
db.withConnection conn:
try: result = some(getRecord(conn, `t`, id))
except NotFoundError: result = none[`t`]()
proc `getIfExistsName`*[D: DbConnType](conn: D, id: `idType`): Option[`t`] =
try: result = some(getRecord(conn, `t`, id))
except NotFoundError: result = none[`t`]()
proc `getAllName`*(db: `dbType`, pagination = none[PaginationParams]()): PagedRecords[`t`] =
db.withConn: result = getAllRecords(conn, `t`, pagination)
db.withConnection conn: result = getAllRecords(conn, `t`, pagination)
proc `getAllName`*[D: DbConnType](
conn: D,
pagination = none[PaginationParams]()): PagedRecords[`t`] =
result = getAllRecords(conn, `t`, pagination)
proc `findWhereName`*(
db: `dbType`,
whereClause: string,
values: varargs[string, dbFormat],
pagination = none[PaginationParams]()): PagedRecords[`t`] =
db.withConn:
db.withConnection conn:
result = findRecordsWhere(conn, `t`, whereClause, values, pagination)
proc `findWhereName`*[D: DbConnType](
conn: D,
whereClause: string,
values: varargs[string, dbFormat],
pagination = none[PaginationParams]()): PagedRecords[`t`] =
result = findRecordsWhere(conn, `t`, whereClause, values, pagination)
proc `createName`*(db: `dbType`, rec: `t`): `t` =
db.withConn: result = createRecord(conn, rec)
db.withConnection conn: result = createRecord(conn, rec)
proc `createName`*[D: DbConnType](conn: D, rec: `t`): `t` =
result = createRecord(conn, rec)
proc `updateName`*(db: `dbType`, rec: `t`): bool =
db.withConn: result = updateRecord(conn, rec)
db.withConnection conn: result = updateRecord(conn, rec)
proc `updateName`*[D: DbConnType](conn: D, rec: `t`): bool =
result = updateRecord(conn, rec)
proc `createOrUpdateName`*(db: `dbType`, rec: `t`): `t` =
db.inTransaction: result = createOrUpdateRecord(conn, rec)
proc `createOrUpdateName`*[D: DbConnType](conn: D, rec: `t`): `t` =
result = createOrUpdateRecord(conn, rec)
proc `deleteName`*(db: `dbType`, rec: `t`): bool =
db.withConn: result = deleteRecord(conn, rec)
db.withConnection conn: result = deleteRecord(conn, rec)
proc `deleteName`*[D: DbConnType](conn: D, rec: `t`): bool =
result = deleteRecord(conn, rec)
proc `deleteName`*(db: `dbType`, id: `idType`): bool =
db.withConn: result = deleteRecord(conn, `t`, id)
db.withConnection conn: result = deleteRecord(conn, `t`, id)
proc `deleteName`*[D: DbConnType](conn: D, id: `idType`): bool =
result = deleteRecord(conn, `t`, id)
macro generateLookup*(dbType: type, modelType: type, fields: seq[string]): untyped =
## Create a lookup procedure for a given set of field names. For example,
@@ -572,42 +810,49 @@ macro generateLookup*(dbType: type, modelType: type, fields: seq[string]): untyp
##
## .. code-block:: Nim
## proc findTodoItemsByOwnerAndPriority*(db: SampleDB,
## owner: string, priority: int): seq[TodoItem]
## owner: string, priority: int,
## pagination = none[PaginationParams]()): PagedRecords[TodoItem]
## proc findTodoItemsByOwnerAndPriority*[D: DbConnType](conn: D,
## owner: string, priority: int,
## pagination = none[PaginationParams]()): PagedRecords[TodoItem]
##
## Use the `db` overload for standalone calls and the `conn` overload inside
## `withConnection` or `inTransaction`.
let fieldNames = fields[1].mapIt($it)
let procName = ident("find" & pluralize($modelType.getType[1]) & "By" & fieldNames.mapIt(it.capitalize).join("And"))
# Create proc skeleton
result = quote do:
proc `procName`*(db: `dbType`): PagedRecords[`modelType`] =
db.withConn: result = findRecordsBy(conn, `modelType`)
var callParams = quote do: @[]
# Add dynamic parameters for the proc definition and inner proc call
# Add dynamic parameters for the generated proc and inner proc call.
for n in fieldNames:
let paramTuple = newNimNode(nnkPar)
paramTuple.add(newColonExpr(ident("field"), newLit(identNameToDb(n))))
paramTuple.add(newColonExpr(ident("value"), ident(n)))
# Add the parameter to the outer call (the generated proc)
# result[3] is ProcDef -> [3]: FormalParams
result[3].add(newIdentDefs(ident(n), ident("string")))
# Build up the AST for the inner procedure call
callParams[1].add(paramTuple)
# Add the optional pagination parameters to the generated proc definition
result[3].add(newIdentDefs(
let dbProcDefAST = quote do:
proc `procName`*(db: `dbType`): PagedRecords[`modelType`] =
db.withConnection conn:
result = findRecordsBy(conn, `modelType`, `callParams`, pagination)
let connProcDefAST = quote do:
proc `procName`*[D: DbConnType](conn: D): PagedRecords[`modelType`] =
result = findRecordsBy(conn, `modelType`, `callParams.copyNimTree`, pagination)
for n in fieldNames:
dbProcDefAST[3].add(newIdentDefs(ident(n), ident("string")))
connProcDefAST[3].add(newIdentDefs(ident(n), ident("string")))
dbProcDefAST[3].add(newIdentDefs(
ident("pagination"), newEmptyNode(),
quote do: none[PaginationParams]()))
# Add the call params to the inner procedure call
# result[6][0][1][0][1] is
# ProcDef -> [6]: StmtList (body) -> [0]: Call ->
# [1]: StmtList (withConn body) -> [0]: Asgn (result =) ->
# [1]: Call (inner findRecords invocation)
result[6][0][1][0][1].add(callParams)
result[6][0][1][0][1].add(quote do: pagination)
connProcDefAST[3].add(newIdentDefs(
ident("pagination"), newEmptyNode(),
quote do: none[PaginationParams]()))
result = newStmtList()
result.add dbProcDefAST
result.add connProcDefAST
macro generateProcsForFieldLookups*(dbType: type, modelsAndFields: openarray[tuple[t: type, fields: seq[string]]]): untyped =
result = newStmtList()
@@ -617,63 +862,213 @@ macro generateProcsForFieldLookups*(dbType: type, modelsAndFields: openarray[tup
let fieldNames = i[1][1][1].mapIt($it)
let procName = ident("find" & $modelType & "sBy" & fieldNames.mapIt(it.capitalize).join("And"))
# Create proc skeleton
let procDefAST = quote do:
proc `procName`*(db: `dbType`): PagedRecords[`modelType`] =
db.withConn: result = findRecordsBy(conn, `modelType`)
var callParams = quote do: @[]
# Add dynamic parameters for the proc definition and inner proc call
# Add dynamic parameters for the generated proc and inner proc call.
for n in fieldNames:
let paramTuple = newNimNode(nnkPar)
paramTuple.add(newColonExpr(ident("field"), newLit(identNameToDb(n))))
paramTuple.add(newColonExpr(ident("value"), ident(n)))
procDefAST[3].add(newIdentDefs(ident(n), ident("string")))
callParams[1].add(paramTuple)
# Add the optional pagination parameters to the generated proc definition
procDefAST[3].add(newIdentDefs(
let dbProcDefAST = quote do:
proc `procName`*(db: `dbType`): PagedRecords[`modelType`] =
db.withConnection conn:
result = findRecordsBy(conn, `modelType`, `callParams`, pagination)
let connProcDefAST = quote do:
proc `procName`*[D: DbConnType](conn: D): PagedRecords[`modelType`] =
result = findRecordsBy(conn, `modelType`, `callParams.copyNimTree`, pagination)
for n in fieldNames:
dbProcDefAST[3].add(newIdentDefs(ident(n), ident("string")))
connProcDefAST[3].add(newIdentDefs(ident(n), ident("string")))
dbProcDefAST[3].add(newIdentDefs(
ident("pagination"), newEmptyNode(),
quote do: none[PaginationParams]()))
procDefAST[6][0][1][0][1].add(callParams)
procDefAST[6][0][1][0][1].add(quote do: pagination)
connProcDefAST[3].add(newIdentDefs(
ident("pagination"), newEmptyNode(),
quote do: none[PaginationParams]()))
result.add procDefAST
result.add dbProcDefAST
result.add connProcDefAST
proc initPool*(
connect: proc(): DbConn,
poolSize = 10,
hardCap = false,
healthCheckQuery = "SELECT 'true' AS alive"): DbConnPool =
## Initialize a new DbConnPool. See the `initDb` procedure in the `Example
## Fiber ORM Usage`_ for an example
macro generateJoinTableProcs*(
dbType, model1Type, model2Type: type,
joinTableName: string): untyped =
## Generate lookup procedures for a pair of models with a join table. For
## example, given the TODO database demonstrated above, where `TodoItem` and
## `TimeEntry` have a many-to-many relationship, you might have a join table
## `todo_items_time_entries` with columns `todo_item_id` and `time_entry_id`.
## This macro will generate the following procedures:
##
## * `connect` must be a factory which creates a new `DbConn`.
## * `poolSize` sets the desired capacity of the connection pool.
## * `hardCap` defaults to `false`.
## When `false`, the pool can grow beyond the configured capacity, but will
## release connections down to the its capacity (no less than `poolSize`).
## .. code-block:: Nim
## proc getTodoItemsByTimeEntry*(db: SampleDB, timeEntry: TimeEntry,
## pagination = none[PaginationParams]()): PagedRecords[TodoItem]
## proc getTodoItemsByTimeEntry*[D: DbConnType](conn: D, timeEntry: TimeEntry,
## pagination = none[PaginationParams]()): PagedRecords[TodoItem]
## proc getTimeEntriesByTodoItem*(db: SampleDB, todoItem: TodoItem,
## pagination = none[PaginationParams]()): PagedRecords[TimeEntry]
## proc getTimeEntriesByTodoItem*[D: DbConnType](conn: D, todoItem: TodoItem,
## pagination = none[PaginationParams]()): PagedRecords[TimeEntry]
##
## When `true` the pool will not create more than its configured capacity.
## It a connection is requested, none are free, and the pool is at
## capacity, this will result in an Error being raised.
## * `healthCheckQuery` should be a simple and fast SQL query that the pool
## can use to test the liveliness of pooled connections.
## `dbType` is expected to be some type that has a defined `withConnection`
## procedure (see `Database Object`_ for details).
## As with the other generated helpers, use the connection overloads when
## you are already inside `withConnection` or `inTransaction`.
##
## .. _Example Fiber ORM Usage: #basic-usage-example-fiber-orm-usage
## .. _Database Object: #database-object
result = newStmtList()
initDbConnPool(DbConnPoolConfig(
connect: connect,
poolSize: poolSize,
hardCap: hardCap,
healthCheckQuery: healthCheckQuery))
if model1Type.getType[1].typeKind == ntyRef or
model2Type.getType[1].typeKind == ntyRef:
raise newException(ValueError,
"fiber_orm model object must be objects, not refs")
template inTransaction*(db: DbConnPool, body: untyped) =
pool.withConn(db):
let model1Name = $(model1Type.getType[1])
let model2Name = $(model2Type.getType[1])
let getModel1Name = ident("get" & pluralize(model1Name) & "By" & model2Name)
let getModel2Name = ident("get" & pluralize(model2Name) & "By" & model1Name)
let id1Type = typeOfColumn(model1Type, "id")
let id2Type = typeOfColumn(model2Type, "id")
let joinTableNameNode = newStrLitNode($joinTableName)
result.add quote do:
proc `getModel1Name`*(
db: `dbType`,
id: `id2Type`,
pagination = none[PaginationParams]()): PagedRecords[`model1Type`] =
db.withConnection conn:
result = findViaJoinTable(
conn,
`joinTableNameNode`,
`model1Type`,
`model2Type`,
id,
pagination)
proc `getModel1Name`*[D: DbConnType](
conn: D,
id: `id2Type`,
pagination = none[PaginationParams]()): PagedRecords[`model1Type`] =
result = findViaJoinTable(
conn,
`joinTableNameNode`,
`model1Type`,
`model2Type`,
id,
pagination)
proc `getModel1Name`*(
db: `dbType`,
rec: `model2Type`,
pagination = none[PaginationParams]()): PagedRecords[`model1Type`] =
db.withConnection conn:
result = findViaJoinTable(
conn,
`joinTableNameNode`,
`model1Type`,
rec,
pagination)
proc `getModel1Name`*[D: DbConnType](
conn: D,
rec: `model2Type`,
pagination = none[PaginationParams]()): PagedRecords[`model1Type`] =
result = findViaJoinTable(
conn,
`joinTableNameNode`,
`model1Type`,
rec,
pagination)
proc `getModel2Name`*(
db: `dbType`,
id: `id1Type`,
pagination = none[PaginationParams]()): PagedRecords[`model2Type`] =
db.withConnection conn:
result = findViaJoinTable(
conn,
`joinTableNameNode`,
`model2Type`,
`model1Type`,
id,
pagination)
proc `getModel2Name`*[D: DbConnType](
conn: D,
id: `id1Type`,
pagination = none[PaginationParams]()): PagedRecords[`model2Type`] =
result = findViaJoinTable(
conn,
`joinTableNameNode`,
`model2Type`,
`model1Type`,
id,
pagination)
proc `getModel2Name`*(
db: `dbType`,
rec: `model1Type`,
pagination = none[PaginationParams]()): PagedRecords[`model2Type`] =
db.withConnection conn:
result = findViaJoinTable(
conn,
`joinTableNameNode`,
`model2Type`,
rec,
pagination)
proc `getModel2Name`*[D: DbConnType](
conn: D,
rec: `model1Type`,
pagination = none[PaginationParams]()): PagedRecords[`model2Type`] =
result = findViaJoinTable(
conn,
`joinTableNameNode`,
`model2Type`,
rec,
pagination)
proc associate*(
db: `dbType`,
rec1: `model1Type`,
rec2: `model2Type`): void =
db.withConnection conn:
associate(conn, `joinTableNameNode`, rec1, rec2)
proc associate*[D: DbConnType](
conn: D,
rec1: `model1Type`,
rec2: `model2Type`): void =
associate(conn, `joinTableNameNode`, rec1, rec2)
proc associate*(
db: `dbType`,
rec2: `model2Type`,
rec1: `model1Type`): void =
db.withConnection conn:
associate(conn, `joinTableNameNode`, rec1, rec2)
proc associate*[D: DbConnType](
conn: D,
rec2: `model2Type`,
rec1: `model1Type`): void =
associate(conn, `joinTableNameNode`, rec1, rec2)
template inTransaction*(db, body: untyped) =
## Execute `body` inside a transaction using a single connection bound to
## `conn`.
##
## When calling generated Fiber ORM helpers inside this block, use the
## overloads that take `conn: D` where `D: DbConnType`. Do not call the
## overloads that take the outer database object, because those call
## `withConnection` again and may acquire a different connection.
## If you need to lock a PostgreSQL row before modifying it, use the
## generated `get<RecordName>ForUpdate` helper.
db.withConnection conn:
conn.exec(sql"BEGIN TRANSACTION")
try:
body

View File

@@ -0,0 +1,3 @@
import db_connector/[db_postgres, db_sqlite]
type DbConnType* = db_postgres.DbConn or db_sqlite.DbConn

View File

@@ -4,94 +4,77 @@
## Simple database connection pooling implementation compatible with Fiber ORM.
import std/db_postgres, std/sequtils, std/strutils, std/sugar
when (NimMajor, NimMinor, NimPatch) < (2, 0, 0):
when not defined(gcArc) and not defined (gcOrc):
{.error: "fiber_orm requires either --mm:arc or --mm:orc.".}
import namespaced_logging
import std/[deques, locks, sequtils, sugar]
import db_connector/db_common
from db_connector/db_sqlite import getRow, close
from db_connector/db_postgres import getRow, close
import ./db_common as fiber_db_common
import ./private/logging
type
DbConnPoolConfig* = object
connect*: () -> DbConn ## Factory procedure to create a new DBConn
poolSize*: int ## The pool capacity.
hardCap*: bool ## Is the pool capacity a hard cap?
##
## When `false`, the pool can grow beyond the configured
## capacity, but will release connections down to the its
## capacity (no less than `poolSize`).
##
## When `true` the pool will not create more than its
## configured capacity. It a connection is requested, none
## are free, and the pool is at capacity, this will result
## in an Error being raised.
healthCheckQuery*: string ## Should be a simple and fast SQL query that the
## pool can use to test the liveliness of pooled
## connections.
DbConnPool*[D: DbConnType] = ptr DbConnPoolObj[D]
PooledDbConn = ref object
conn: DbConn
id: int
free: bool
DbConnPool* = ref object
DbConnPoolObj[D: DbConnType] = object
## Database connection pool
conns: seq[PooledDbConn]
cfg: DbConnPoolConfig
lastId: int
connect: proc (): D {.raises: [DbError].}
healthCheckQuery: SqlQuery
entries: Deque[D]
cond: Cond
lock: Lock
var logNs {.threadvar.}: LoggingNamespace
template log(): untyped =
if logNs.isNil: logNs = initLoggingNamespace(name = "fiber_orm/pool", level = lvlNotice)
logNs
proc close*[D: DbConnType](pool: DbConnPool[D]) =
## Safely close all connections and release resources for the given pool.
getLogger("pool").debug("closing connection pool")
withLock(pool.lock):
while pool.entries.len > 0: close(pool.entries.popFirst())
proc initDbConnPool*(cfg: DbConnPoolConfig): DbConnPool =
log().debug("Initializing new pool (size: " & $cfg.poolSize)
result = DbConnPool(
conns: @[],
cfg: cfg)
deinitLock(pool.lock)
deinitCond(pool.cond)
`=destroy`(pool[])
deallocShared(pool)
proc newConn(pool: DbConnPool): PooledDbConn =
log().debug("Creating a new connection to add to the pool.")
pool.lastId += 1
let conn = pool.cfg.connect()
result = PooledDbConn(
conn: conn,
id: pool.lastId,
free: true)
pool.conns.add(result)
proc maintain(pool: DbConnPool): void =
log().debug("Maintaining pool. $# connections." % [$pool.conns.len])
pool.conns.keepIf(proc (pc: PooledDbConn): bool =
if not pc.free: return true
proc newDbConnPool*[D: DbConnType](
poolSize: int,
connectFunc: proc(): D {.raises: [DbError].},
healthCheckQuery = "SELECT 1;"): DbConnPool[D] =
## Initialize a new DbConnPool. See the `initDb` procedure in the `Example
## Fiber ORM Usage`_ for an example
##
## * `connect` must be a factory which creates a new `DbConn`.
## * `poolSize` sets the desired capacity of the connection pool.
## * `healthCheckQuery` should be a simple and fast SQL query that the pool
## can use to test the liveliness of pooled connections. By default it uses
## `SELECT 1;`
##
## .. _Example Fiber ORM Usage: ../fiber_orm.html#basic-usage-example-fiber-orm-usage
result = cast[DbConnPool[D]](allocShared0(sizeof(DbConnPoolObj[D])))
initCond(result.cond)
initLock(result.lock)
result.entries = initDeque[D](poolSize)
result.connect = connectFunc
result.healthCheckQuery = sql(healthCheckQuery)
try:
discard getRow(pc.conn, sql(pool.cfg.healthCheckQuery), [])
return true
except:
try: pc.conn.close() # try to close the connection
except: discard ""
return false
)
log().debug(
"Pruned dead connections. $# connections remaining." %
[$pool.conns.len])
for _ in 0 ..< poolSize: result.entries.addLast(connectFunc())
except DbError as ex:
try: result.close()
except: discard
getLogger("pool").error(
msg = "unable to initialize connection pool",
err = ex)
raise ex
let freeConns = pool.conns.filterIt(it.free)
if pool.conns.len > pool.cfg.poolSize and freeConns.len > 0:
let numToCull = min(freeConns.len, pool.conns.len - pool.cfg.poolSize)
if numToCull > 0:
let toCull = freeConns[0..numToCull]
pool.conns.keepIf((pc) => toCull.allIt(it.id != pc.id))
for culled in toCull:
try: culled.conn.close()
except: discard ""
log().debug(
"Trimming pool size. Culled $# free connections. $# connections remaining." %
[$toCull.len, $pool.conns.len])
proc take*(pool: DbConnPool): tuple[id: int, conn: DbConn] =
proc take*[D: DbConnType](pool: DbConnPool[D]): D {.raises: [DbError], gcsafe.} =
## Request a connection from the pool. Returns a DbConn if the pool has free
## connections, or if it has the capacity to create a new connection. If the
## pool is configured with a hard capacity limit and is out of free
@@ -99,32 +82,33 @@ proc take*(pool: DbConnPool): tuple[id: int, conn: DbConn] =
##
## Connections taken must be returned via `release` when the caller is
## finished using them in order for them to be released back to the pool.
pool.maintain
let freeConns = pool.conns.filterIt(it.free)
withLock(pool.lock):
while pool.entries.len == 0: wait(pool.cond, pool.lock)
result = pool.entries.popFirst()
log().debug(
"Providing a new connection ($# currently free)." % [$freeConns.len])
# check that the connection is healthy
try: discard getRow(result, pool.healthCheckQuery, [])
except DbError:
{.gcsafe.}:
# if it's not, let's try to close it and create a new connection
try:
getLogger("pool").info(
"pooled connection failed health check, opening a new connection")
close(result)
except: discard
result = pool.connect()
let reserved =
if freeConns.len > 0: freeConns[0]
else: pool.newConn()
reserved.free = false
log().debug("Reserve connection $#" % [$reserved.id])
return (id: reserved.id, conn: reserved.conn)
proc release*(pool: DbConnPool, connId: int): void =
proc release*[D: DbConnType](pool: DbConnPool[D], conn: D) {.raises: [], gcsafe.} =
## Release a connection back to the pool.
log().debug("Reclaiming released connaction $#" % [$connId])
let foundConn = pool.conns.filterIt(it.id == connId)
if foundConn.len > 0: foundConn[0].free = true
withLock(pool.lock):
pool.entries.addLast(conn)
signal(pool.cond)
template withConn*(pool: DbConnPool, stmt: untyped): untyped =
template withConnection*[D: DbConnType](pool: DbConnPool[D], conn, stmt: untyped): untyped =
## Convenience template to provide a connection from the pool for use in a
## statement block, automatically releasing that connnection when done.
##
## The provided connection is injected as the variable `conn` in the
## statement body.
let (connId, conn {.inject.}) = take(pool)
block:
let conn = take(pool)
try: stmt
finally: release(pool, connId)
finally: release(pool, conn)

View File

@@ -0,0 +1,34 @@
import std/[json, options]
import namespaced_logging
export namespaced_logging.log
export namespaced_logging.debug
export namespaced_logging.info
export namespaced_logging.notice
export namespaced_logging.warn
export namespaced_logging.error
export namespaced_logging.fatal
var logService {.threadvar.}: Option[ThreadLocalLogService]
var logger {.threadvar.}: Option[Logger]
proc makeQueryLogEntry(
m: string,
sql: string,
args: openArray[(string, string)] = []): JsonNode =
result = %*{ "method": m, "sql": sql }
for (k, v) in args: result[k] = %v
proc logQuery*(methodName: string, sqlStmt: string, args: openArray[(string, string)] = []) =
# namespaced_logging would do this check for us, but we don't want to even
# build the log object if we're not actually logging
if logService.isNone: return
if logger.isNone: logger = logService.getLogger("fiber_orm/query")
logger.debug(makeQueryLogEntry(methodName, sqlStmt, args))
proc enableDbLogging*(svc: ThreadLocalLogService) =
logService = some(svc)
proc getLogger*(scope: string): Option[Logger] =
logService.getLogger("fiber_orm/" & scope)

View File

@@ -3,12 +3,17 @@
# Copyright 2019 Jonathan Bernard <jonathan@jdbernard.com>
## Utility methods used internally by Fiber ORM.
import json, macros, options, sequtils, strutils, times, unicode,
uuids
import std/[json, macros, options, sequtils, strutils, times, unicode]
import uuids
import nre except toSeq
import std/nre except toSeq
type
PaginationParams* = object
pageSize*: int
offset*: int
orderBy*: Option[seq[string]]
MutateClauses* = object
## Data structure to hold information about the clauses that should be
## added to a query. How these clauses are used will depend on the query.
@@ -22,9 +27,11 @@ const ISO_8601_FORMATS = @[
"yyyy-MM-dd'T'HH:mm:ssz",
"yyyy-MM-dd'T'HH:mm:sszzz",
"yyyy-MM-dd'T'HH:mm:ss'.'fffzzz",
"yyyy-MM-dd'T'HH:mm:ss'.'ffffzzz",
"yyyy-MM-dd HH:mm:ssz",
"yyyy-MM-dd HH:mm:sszzz",
"yyyy-MM-dd HH:mm:ss'.'fffzzz"
"yyyy-MM-dd HH:mm:ss'.'fffzzz",
"yyyy-MM-dd HH:mm:ss'.'ffffzzz"
]
proc parseIso8601(val: string): DateTime =
@@ -102,7 +109,7 @@ proc dbFormat*[T](list: seq[T]): string =
proc dbFormat*[T](item: T): string =
## For all other types, fall back on a defined `$` function to create a
## string version of the value we can include in an SQL query>
## string version of the value we can include in an SQL query.
return $item
type DbArrayParseState = enum
@@ -126,18 +133,20 @@ proc parsePGDatetime*(val: string): DateTime =
var correctedVal = val;
# PostgreSQL will truncate any trailing 0's in the millisecond value leading
# to values like `2020-01-01 16:42.3+00`. This cannot currently be parsed by
# the standard times format as it expects exactly three digits for
# millisecond values. So we have to detect this and pad out the millisecond
# value to 3 digits.
let PG_PARTIAL_FORMAT_REGEX = re"(\d{4}-\d{2}-\d{2}( |'T')\d{2}:\d{2}:\d{2}\.)(\d{1,2})(\S+)?"
# The Nim `times#format` function only recognizes 3-digit millisecond values
# but PostgreSQL will sometimes send 1-2 digits, truncating any trailing 0's,
# or sometimes provide more than three digits of preceision in the millisecond value leading
# to values like `2020-01-01 16:42.3+00` or `2025-01-06 00:56:00.9007+00`.
# This cannot currently be parsed by the standard times format as it expects
# exactly three digits for millisecond values. So we have to detect this and
# coerce the millisecond value to exactly 3 digits.
let PG_PARTIAL_FORMAT_REGEX = re"(\d{4}-\d{2}-\d{2}( |'T')\d{2}:\d{2}:\d{2}\.)(\d+)(\S+)?"
let match = val.match(PG_PARTIAL_FORMAT_REGEX)
if match.isSome:
let c = match.get.captures
if c.toSeq.len == 2: correctedVal = c[0] & alignLeft(c[2], 3, '0')
else: correctedVal = c[0] & alignLeft(c[2], 3, '0') & c[3]
if c.toSeq.len == 2: correctedVal = c[0] & alignLeft(c[2], 3, '0')[0..2]
else: correctedVal = c[0] & alignLeft(c[2], 3, '0')[0..2] & c[3]
var errStr = ""
@@ -146,7 +155,7 @@ proc parsePGDatetime*(val: string): DateTime =
try: return correctedVal.parse(df)
except: errStr &= "\n\t" & getCurrentExceptionMsg()
raise newException(ValueError, "Cannot parse PG date. Tried:" & errStr)
raise newException(ValueError, "Cannot parse PG date '" & correctedVal & "'. Tried:" & errStr)
proc parseDbArray*(val: string): seq[string] =
## Parse a Postgres array column into a Nim seq[string]
@@ -207,21 +216,14 @@ proc parseDbArray*(val: string): seq[string] =
if not (parseState == inQuote) and curStr.len > 0:
result.add(curStr)
proc createParseStmt*(t, value: NimNode): NimNode =
## Utility method to create the Nim cod required to parse a value coming from
func createParseStmt*(t, value: NimNode): NimNode =
## Utility method to create the Nim code required to parse a value coming from
## the a database query. This is used by functions like `rowToModel` to parse
## the dataabase columns into the Nim object fields.
#echo "Creating parse statment for ", t.treeRepr
if t.typeKind == ntyObject:
if t.getType == UUID.getType:
result = quote do: parseUUID(`value`)
elif t.getType == DateTime.getType:
result = quote do: parsePGDatetime(`value`)
elif t.getTypeInst == Option.getType:
if t.getTypeInst == Option.getType:
var innerType = t.getTypeImpl[2][0] # start at the first RecList
# If the value is a non-pointer type, there is another inner RecList
if innerType.kind == nnkRecList: innerType = innerType[0]
@@ -232,7 +234,33 @@ proc createParseStmt*(t, value: NimNode): NimNode =
if `value`.len == 0: none[`innerType`]()
else: some(`parseStmt`)
else: error "Unknown value object type: " & $t.getTypeInst
elif t.getType == UUID.getType:
result = quote do: parseUUID(`value`)
elif t.getType == DateTime.getType:
result = quote do: parsePGDatetime(`value`)
else: error "Cannot parse column with unknown object type: " & $t.getTypeInst
elif t.typeKind == ntyGenericInst:
if t.kind == nnkBracketExpr and
t.len > 0 and
t[0] == Option.getType:
var innerType = t.getTypeInst[1]
let parseStmt = createParseStmt(innerType, value)
result = quote do:
if `value`.len == 0: none[`innerType`]()
else: some(`parseStmt`)
else: error "Cannot parse column with unknown generic instance type: " & $t.getTypeInst
elif t.typeKind == ntyDistinct:
result = quote do:
block:
let tmp: `t` = `value`
tmp
elif t.typeKind == ntyRef:
@@ -240,7 +268,7 @@ proc createParseStmt*(t, value: NimNode): NimNode =
result = quote do: parseJson(`value`)
else:
error "Unknown ref type: " & $t.getTypeInst
error "Cannot parse column with unknown ref type: " & $t.getTypeInst
elif t.typeKind == ntySequence:
let innerType = t[1]
@@ -259,37 +287,81 @@ proc createParseStmt*(t, value: NimNode): NimNode =
result = quote do: parseFloat(`value`)
elif t.typeKind == ntyBool:
result = quote do: "true".startsWith(`value`.toLower)
result = quote do: "true".startsWith(`value`.toLower) or `value` == "1"
elif t.typeKind == ntyEnum:
let innerType = t.getTypeInst
result = quote do: parseEnum[`innerType`](`value`)
else:
error "Unknown value type: " & $t.typeKind
error "Cannot parse column with unknown value type: " & $t.typeKind
func fields(t: NimNode): seq[tuple[fieldIdent: NimNode, fieldType: NimNode]] =
#[
debugEcho "T: " & t.treeRepr
debugEcho "T.kind: " & $t.kind
debugEcho "T.typeKind: " & $t.typeKind
debugEcho "T.GET_TYPE[1]: " & t.getType[1].treeRepr
debugEcho "T.GET_TYPE[1].kind: " & $t.getType[1].kind
debugEcho "T.GET_TYPE[1].typeKind: " & $t.getType[1].typeKind
debugEcho "T.GET_TYPE: " & t.getType.treeRepr
debugEcho "T.GET_TYPE[1].GET_TYPE: " & t.getType[1].getType.treeRepr
]#
# Get the object type AST, with base object (if present) and record list.
var objDefAst: NimNode
if t.typeKind == ntyObject: objDefAst = t.getType
elif t.typeKind == ntyTypeDesc:
# In this case we have a type AST that is like:
# BracketExpr
# Sym "typeDesc"
# Sym "ModelType"
objDefAst = t.
getType[1]. # get the Sym "ModelType"
getType # get the object definition type
if objDefAst.kind != nnkObjectTy:
error ("unable to enumerate the fields for model type '$#', " &
"tried to resolve the type of the provided symbol to an object " &
"definition (nnkObjectTy) but got a '$#'.\pAST:\p$#") % [
$t, $objDefAst.kind, objDefAst.treeRepr ]
else:
error ("unable to enumerate the fields for model type '$#', " &
"expected a symbol with type ntyTypeDesc but got a '$#'.\pAST:\p$#") % [
$t, $t.typeKind, t.treeRepr ]
# At this point objDefAst should look something like:
# ObjectTy
# Empty
# Sym "BaseObject"" | Empty
# RecList
# Sym "field1"
# Sym "field2"
# ...
if objDefAst[1].kind == nnkSym:
# We have a base class symbol, let's recurse and try and resolve the fields
# for the base class
for fieldDef in objDefAst[1].fields: result.add(fieldDef)
for fieldDef in objDefAst[2].children:
# objDefAst[2] is a RecList of
# ignore AST nodes that are not field definitions
if fieldDef.kind == nnkIdentDefs: result.add((fieldDef[0], fieldDef[1]))
elif fieldDef.kind == nnkSym: result.add((fieldDef, fieldDef.getTypeInst))
else: error "unknown object field definition AST: $#" % $fieldDef.kind
template walkFieldDefs*(t: NimNode, body: untyped) =
## Iterate over every field of the given Nim object, yielding and defining
## `fieldIdent` and `fieldType`, the name of the field as a Nim Ident node
## and the type of the field as a Nim Type node respectively.
let tTypeImpl = t.getTypeImpl
for (fieldIdent {.inject.}, fieldType {.inject.}) in t.fields: body
var nodeToItr: NimNode
if tTypeImpl.typeKind == ntyObject: nodeToItr = tTypeImpl[2]
elif tTypeImpl.typeKind == ntyTypeDesc: nodeToItr = tTypeImpl.getType[1].getType[2]
else: error $t & " is not an object or type desc (it's a " & $tTypeImpl.typeKind & ")."
for fieldDef {.inject.} in nodeToItr.children:
# ignore AST nodes that are not field definitions
if fieldDef.kind == nnkIdentDefs:
let fieldIdent {.inject.} = fieldDef[0]
let fieldType {.inject.} = fieldDef[1]
body
elif fieldDef.kind == nnkSym:
let fieldIdent {.inject.} = fieldDef
let fieldType {.inject.} = fieldDef.getType
body
#[ TODO: replace walkFieldDefs with things like this:
func columnNamesForModel*(modelType: typedesc): seq[string] =
modelType.fields.mapIt(identNameToDb($it[0]))
]#
macro columnNamesForModel*(modelType: typed): seq[string] =
## Return the column names corresponding to the the fields of the given
@@ -317,6 +389,7 @@ macro rowToModel*(modelType: typed, row: seq[string]): untyped =
createParseStmt(fieldType, itemLookup)))
idx += 1
#[
macro listFields*(t: typed): untyped =
var fields: seq[tuple[n: string, t: string]] = @[]
t.walkFieldDefs:
@@ -324,6 +397,7 @@ macro listFields*(t: typed): untyped =
else: fields.add((n: $fieldIdent, t: $fieldType))
result = newLit(fields)
]#
proc typeOfColumn*(modelType: NimNode, colName: string): NimNode =
## Given a model type and a column name, return the Nim type for that column.
@@ -388,6 +462,19 @@ macro populateMutateClauses*(t: typed, newRecord: bool, mc: var MutateClauses):
`mc`.placeholders.add("?")
`mc`.values.add(dbFormat(`t`.`fieldIdent`))
proc getPagingClause*(page: PaginationParams): string =
## Given a `PaginationParams` object, return the SQL clause necessary to
## limit the number of records returned by a query.
result = ""
if page.orderBy.isSome:
let orderByClause = page.orderBy.get.map(identNameToDb).join(",")
result &= " ORDER BY " & orderByClause
else:
result &= " ORDER BY id"
result &= " LIMIT " & $page.pageSize & " OFFSET " & $page.offset
## .. _model class: ../fiber_orm.html#objectminusrelational-modeling-model-class
## .. _rules for name mapping: ../fiber_orm.html
## .. _table name: ../fiber_orm.html