Add connection overloads for generated ORM procs

The generated ORM helpers currently only accept the database wrapper
type as their first argument. That works well for the common case, but
it becomes misleading inside inTransaction blocks because the generated
proc will call withConnection again and may therefore use a different
connection than the one that is participating in the transaction.

Add DbConnType-constrained overloads for the generated model CRUD/query
procs, generated lookups, and generated join-table helpers. This lets
callers explicitly use the transaction connection while keeping the
existing dbType-based API intact for non-transactional call sites.

This makes the intended transactional usage straightforward:

  db.inTransaction:
    var userRecord = conn.getUser("userId1")
    userRecord.visitCount += 1
    discard conn.updateUser(userRecord)

AI-Assisted: yes
AI-Tool: OpenAI Codes / gpt-5.4 xhigh
This commit is contained in:
2026-03-24 21:39:13 -05:00
parent bb36bba864
commit 1a9314fe4f

View File

@@ -628,17 +628,32 @@ macro generateProcsForModels*(dbType: type, modelTypes: openarray[type]): untype
proc `getName`*(db: `dbType`, id: `idType`): `t` = proc `getName`*(db: `dbType`, id: `idType`): `t` =
db.withConnection conn: result = getRecord(conn, `t`, id) db.withConnection conn: result = getRecord(conn, `t`, id)
proc `getName`*[D: DbConnType](conn: D, id: `idType`): `t` =
result = getRecord(conn, `t`, id)
proc `tryGetName`*(db: `dbType`, id: `idType`): Option[`t`] = proc `tryGetName`*(db: `dbType`, id: `idType`): Option[`t`] =
db.withConnection conn: result = tryGetRecord(conn, `t`, id) db.withConnection conn: result = tryGetRecord(conn, `t`, id)
proc `tryGetName`*[D: DbConnType](conn: D, id: `idType`): Option[`t`] =
result = tryGetRecord(conn, `t`, id)
proc `getIfExistsName`*(db: `dbType`, id: `idType`): Option[`t`] = proc `getIfExistsName`*(db: `dbType`, id: `idType`): Option[`t`] =
db.withConnection conn: db.withConnection conn:
try: result = some(getRecord(conn, `t`, id)) try: result = some(getRecord(conn, `t`, id))
except NotFoundError: result = none[`t`]() except NotFoundError: result = none[`t`]()
proc `getIfExistsName`*[D: DbConnType](conn: D, id: `idType`): Option[`t`] =
try: result = some(getRecord(conn, `t`, id))
except NotFoundError: result = none[`t`]()
proc `getAllName`*(db: `dbType`, pagination = none[PaginationParams]()): PagedRecords[`t`] = proc `getAllName`*(db: `dbType`, pagination = none[PaginationParams]()): PagedRecords[`t`] =
db.withConnection conn: result = getAllRecords(conn, `t`, pagination) db.withConnection conn: result = getAllRecords(conn, `t`, pagination)
proc `getAllName`*[D: DbConnType](
conn: D,
pagination = none[PaginationParams]()): PagedRecords[`t`] =
result = getAllRecords(conn, `t`, pagination)
proc `findWhereName`*( proc `findWhereName`*(
db: `dbType`, db: `dbType`,
whereClause: string, whereClause: string,
@@ -647,21 +662,43 @@ macro generateProcsForModels*(dbType: type, modelTypes: openarray[type]): untype
db.withConnection conn: db.withConnection conn:
result = findRecordsWhere(conn, `t`, whereClause, values, pagination) result = findRecordsWhere(conn, `t`, whereClause, values, pagination)
proc `findWhereName`*[D: DbConnType](
conn: D,
whereClause: string,
values: varargs[string, dbFormat],
pagination = none[PaginationParams]()): PagedRecords[`t`] =
result = findRecordsWhere(conn, `t`, whereClause, values, pagination)
proc `createName`*(db: `dbType`, rec: `t`): `t` = proc `createName`*(db: `dbType`, rec: `t`): `t` =
db.withConnection conn: result = createRecord(conn, rec) db.withConnection conn: result = createRecord(conn, rec)
proc `createName`*[D: DbConnType](conn: D, rec: `t`): `t` =
result = createRecord(conn, rec)
proc `updateName`*(db: `dbType`, rec: `t`): bool = proc `updateName`*(db: `dbType`, rec: `t`): bool =
db.withConnection conn: result = updateRecord(conn, rec) db.withConnection conn: result = updateRecord(conn, rec)
proc `updateName`*[D: DbConnType](conn: D, rec: `t`): bool =
result = updateRecord(conn, rec)
proc `createOrUpdateName`*(db: `dbType`, rec: `t`): `t` = proc `createOrUpdateName`*(db: `dbType`, rec: `t`): `t` =
db.inTransaction: result = createOrUpdateRecord(conn, rec) db.inTransaction: result = createOrUpdateRecord(conn, rec)
proc `createOrUpdateName`*[D: DbConnType](conn: D, rec: `t`): `t` =
result = createOrUpdateRecord(conn, rec)
proc `deleteName`*(db: `dbType`, rec: `t`): bool = proc `deleteName`*(db: `dbType`, rec: `t`): bool =
db.withConnection conn: result = deleteRecord(conn, rec) db.withConnection conn: result = deleteRecord(conn, rec)
proc `deleteName`*[D: DbConnType](conn: D, rec: `t`): bool =
result = deleteRecord(conn, rec)
proc `deleteName`*(db: `dbType`, id: `idType`): bool = proc `deleteName`*(db: `dbType`, id: `idType`): bool =
db.withConnection conn: result = deleteRecord(conn, `t`, id) db.withConnection conn: result = deleteRecord(conn, `t`, id)
proc `deleteName`*[D: DbConnType](conn: D, id: `idType`): bool =
result = deleteRecord(conn, `t`, id)
macro generateLookup*(dbType: type, modelType: type, fields: seq[string]): untyped = macro generateLookup*(dbType: type, modelType: type, fields: seq[string]): untyped =
## Create a lookup procedure for a given set of field names. For example, ## Create a lookup procedure for a given set of field names. For example,
## given the TODO database demostrated above, ## given the TODO database demostrated above,
@@ -676,39 +713,39 @@ macro generateLookup*(dbType: type, modelType: type, fields: seq[string]): untyp
## owner: string, priority: int): seq[TodoItem] ## owner: string, priority: int): seq[TodoItem]
let fieldNames = fields[1].mapIt($it) let fieldNames = fields[1].mapIt($it)
let procName = ident("find" & pluralize($modelType.getType[1]) & "By" & fieldNames.mapIt(it.capitalize).join("And")) let procName = ident("find" & pluralize($modelType.getType[1]) & "By" & fieldNames.mapIt(it.capitalize).join("And"))
# Create proc skeleton
result = quote do:
proc `procName`*(db: `dbType`): PagedRecords[`modelType`] =
db.withConnection conn: result = findRecordsBy(conn, `modelType`)
var callParams = quote do: @[] var callParams = quote do: @[]
# Add dynamic parameters for the proc definition and inner proc call # Add dynamic parameters for the generated proc and inner proc call.
for n in fieldNames: for n in fieldNames:
let paramTuple = newNimNode(nnkPar) let paramTuple = newNimNode(nnkPar)
paramTuple.add(newColonExpr(ident("field"), newLit(identNameToDb(n)))) paramTuple.add(newColonExpr(ident("field"), newLit(identNameToDb(n))))
paramTuple.add(newColonExpr(ident("value"), ident(n))) paramTuple.add(newColonExpr(ident("value"), ident(n)))
# Add the parameter to the outer call (the generated proc)
# result[3] is ProcDef -> [3]: FormalParams
result[3].add(newIdentDefs(ident(n), ident("string")))
# Build up the AST for the inner procedure call
callParams[1].add(paramTuple) callParams[1].add(paramTuple)
# Add the optional pagination parameters to the generated proc definition let dbProcDefAST = quote do:
result[3].add(newIdentDefs( proc `procName`*(db: `dbType`): PagedRecords[`modelType`] =
db.withConnection conn:
result = findRecordsBy(conn, `modelType`, `callParams`, pagination)
let connProcDefAST = quote do:
proc `procName`*[D: DbConnType](conn: D): PagedRecords[`modelType`] =
result = findRecordsBy(conn, `modelType`, `callParams.copyNimTree`, pagination)
for n in fieldNames:
dbProcDefAST[3].add(newIdentDefs(ident(n), ident("string")))
connProcDefAST[3].add(newIdentDefs(ident(n), ident("string")))
dbProcDefAST[3].add(newIdentDefs(
ident("pagination"), newEmptyNode(), ident("pagination"), newEmptyNode(),
quote do: none[PaginationParams]())) quote do: none[PaginationParams]()))
# Add the call params to the inner procedure call connProcDefAST[3].add(newIdentDefs(
# result[6][0][1][0][1] is ident("pagination"), newEmptyNode(),
# ProcDef -> [6]: StmtList (body) -> [0]: Command -> quote do: none[PaginationParams]()))
# [2]: StmtList (withConnection body) -> [0]: Asgn (result =) ->
# [1]: Call (inner findRecords invocation) result = newStmtList()
result[6][0][2][0][1].add(callParams) result.add dbProcDefAST
result[6][0][2][0][1].add(quote do: pagination) result.add connProcDefAST
macro generateProcsForFieldLookups*(dbType: type, modelsAndFields: openarray[tuple[t: type, fields: seq[string]]]): untyped = macro generateProcsForFieldLookups*(dbType: type, modelsAndFields: openarray[tuple[t: type, fields: seq[string]]]): untyped =
result = newStmtList() result = newStmtList()
@@ -718,32 +755,38 @@ macro generateProcsForFieldLookups*(dbType: type, modelsAndFields: openarray[tup
let fieldNames = i[1][1][1].mapIt($it) let fieldNames = i[1][1][1].mapIt($it)
let procName = ident("find" & $modelType & "sBy" & fieldNames.mapIt(it.capitalize).join("And")) let procName = ident("find" & $modelType & "sBy" & fieldNames.mapIt(it.capitalize).join("And"))
# Create proc skeleton
let procDefAST = quote do:
proc `procName`*(db: `dbType`): PagedRecords[`modelType`] =
db.withConnection conn: result = findRecordsBy(conn, `modelType`)
var callParams = quote do: @[] var callParams = quote do: @[]
# Add dynamic parameters for the proc definition and inner proc call # Add dynamic parameters for the generated proc and inner proc call.
for n in fieldNames: for n in fieldNames:
let paramTuple = newNimNode(nnkPar) let paramTuple = newNimNode(nnkPar)
paramTuple.add(newColonExpr(ident("field"), newLit(identNameToDb(n)))) paramTuple.add(newColonExpr(ident("field"), newLit(identNameToDb(n))))
paramTuple.add(newColonExpr(ident("value"), ident(n))) paramTuple.add(newColonExpr(ident("value"), ident(n)))
procDefAST[3].add(newIdentDefs(ident(n), ident("string")))
callParams[1].add(paramTuple) callParams[1].add(paramTuple)
# Add the optional pagination parameters to the generated proc definition let dbProcDefAST = quote do:
procDefAST[3].add(newIdentDefs( proc `procName`*(db: `dbType`): PagedRecords[`modelType`] =
db.withConnection conn:
result = findRecordsBy(conn, `modelType`, `callParams`, pagination)
let connProcDefAST = quote do:
proc `procName`*[D: DbConnType](conn: D): PagedRecords[`modelType`] =
result = findRecordsBy(conn, `modelType`, `callParams.copyNimTree`, pagination)
for n in fieldNames:
dbProcDefAST[3].add(newIdentDefs(ident(n), ident("string")))
connProcDefAST[3].add(newIdentDefs(ident(n), ident("string")))
dbProcDefAST[3].add(newIdentDefs(
ident("pagination"), newEmptyNode(), ident("pagination"), newEmptyNode(),
quote do: none[PaginationParams]())) quote do: none[PaginationParams]()))
procDefAST[6][0][1][0][1].add(callParams) connProcDefAST[3].add(newIdentDefs(
procDefAST[6][0][1][0][1].add(quote do: pagination) ident("pagination"), newEmptyNode(),
quote do: none[PaginationParams]()))
result.add procDefAST result.add dbProcDefAST
result.add connProcDefAST
macro generateJoinTableProcs*( macro generateJoinTableProcs*(
dbType, model1Type, model2Type: type, dbType, model1Type, model2Type: type,
@@ -791,6 +834,18 @@ macro generateJoinTableProcs*(
id, id,
pagination) pagination)
proc `getModel1Name`*[D: DbConnType](
conn: D,
id: `id2Type`,
pagination = none[PaginationParams]()): PagedRecords[`model1Type`] =
result = findViaJoinTable(
conn,
`joinTableNameNode`,
`model1Type`,
`model2Type`,
id,
pagination)
proc `getModel1Name`*( proc `getModel1Name`*(
db: `dbType`, db: `dbType`,
rec: `model2Type`, rec: `model2Type`,
@@ -803,10 +858,21 @@ macro generateJoinTableProcs*(
rec, rec,
pagination) pagination)
proc `getModel1Name`*[D: DbConnType](
conn: D,
rec: `model2Type`,
pagination = none[PaginationParams]()): PagedRecords[`model1Type`] =
result = findViaJoinTable(
conn,
`joinTableNameNode`,
`model1Type`,
rec,
pagination)
proc `getModel2Name`*( proc `getModel2Name`*(
db: `dbType`, db: `dbType`,
id: `id1Type`, id: `id1Type`,
pagination = none[PaginationParams]()): Pagedrecords[`model2Type`] = pagination = none[PaginationParams]()): PagedRecords[`model2Type`] =
db.withConnection conn: db.withConnection conn:
result = findViaJoinTable( result = findViaJoinTable(
conn, conn,
@@ -816,10 +882,22 @@ macro generateJoinTableProcs*(
id, id,
pagination) pagination)
proc `getModel2Name`*[D: DbConnType](
conn: D,
id: `id1Type`,
pagination = none[PaginationParams]()): PagedRecords[`model2Type`] =
result = findViaJoinTable(
conn,
`joinTableNameNode`,
`model2Type`,
`model1Type`,
id,
pagination)
proc `getModel2Name`*( proc `getModel2Name`*(
db: `dbType`, db: `dbType`,
rec: `model1Type`, rec: `model1Type`,
pagination = none[PaginationParams]()): Pagedrecords[`model2Type`] = pagination = none[PaginationParams]()): PagedRecords[`model2Type`] =
db.withConnection conn: db.withConnection conn:
result = findViaJoinTable( result = findViaJoinTable(
conn, conn,
@@ -828,6 +906,17 @@ macro generateJoinTableProcs*(
rec, rec,
pagination) pagination)
proc `getModel2Name`*[D: DbConnType](
conn: D,
rec: `model1Type`,
pagination = none[PaginationParams]()): PagedRecords[`model2Type`] =
result = findViaJoinTable(
conn,
`joinTableNameNode`,
`model2Type`,
rec,
pagination)
proc associate*( proc associate*(
db: `dbType`, db: `dbType`,
rec1: `model1Type`, rec1: `model1Type`,
@@ -835,6 +924,12 @@ macro generateJoinTableProcs*(
db.withConnection conn: db.withConnection conn:
associate(conn, `joinTableNameNode`, rec1, rec2) associate(conn, `joinTableNameNode`, rec1, rec2)
proc associate*[D: DbConnType](
conn: D,
rec1: `model1Type`,
rec2: `model2Type`): void =
associate(conn, `joinTableNameNode`, rec1, rec2)
proc associate*( proc associate*(
db: `dbType`, db: `dbType`,
rec2: `model2Type`, rec2: `model2Type`,
@@ -842,6 +937,12 @@ macro generateJoinTableProcs*(
db.withConnection conn: db.withConnection conn:
associate(conn, `joinTableNameNode`, rec1, rec2) associate(conn, `joinTableNameNode`, rec1, rec2)
proc associate*[D: DbConnType](
conn: D,
rec2: `model2Type`,
rec1: `model1Type`): void =
associate(conn, `joinTableNameNode`, rec1, rec2)
template inTransaction*(db, body: untyped) = template inTransaction*(db, body: untyped) =
db.withConnection conn: db.withConnection conn:
conn.exec(sql"BEGIN TRANSACTION") conn.exec(sql"BEGIN TRANSACTION")