Add connection overloads for generated ORM procs
The generated ORM helpers currently only accept the database wrapper
type as their first argument. That works well for the common case, but
it becomes misleading inside inTransaction blocks because the generated
proc will call withConnection again and may therefore use a different
connection than the one that is participating in the transaction.
Add DbConnType-constrained overloads for the generated model CRUD/query
procs, generated lookups, and generated join-table helpers. This lets
callers explicitly use the transaction connection while keeping the
existing dbType-based API intact for non-transactional call sites.
This makes the intended transactional usage straightforward:
db.inTransaction:
var userRecord = conn.getUser("userId1")
userRecord.visitCount += 1
discard conn.updateUser(userRecord)
AI-Assisted: yes
AI-Tool: OpenAI Codes / gpt-5.4 xhigh
This commit is contained in:
@@ -628,17 +628,32 @@ macro generateProcsForModels*(dbType: type, modelTypes: openarray[type]): untype
|
|||||||
proc `getName`*(db: `dbType`, id: `idType`): `t` =
|
proc `getName`*(db: `dbType`, id: `idType`): `t` =
|
||||||
db.withConnection conn: result = getRecord(conn, `t`, id)
|
db.withConnection conn: result = getRecord(conn, `t`, id)
|
||||||
|
|
||||||
|
proc `getName`*[D: DbConnType](conn: D, id: `idType`): `t` =
|
||||||
|
result = getRecord(conn, `t`, id)
|
||||||
|
|
||||||
proc `tryGetName`*(db: `dbType`, id: `idType`): Option[`t`] =
|
proc `tryGetName`*(db: `dbType`, id: `idType`): Option[`t`] =
|
||||||
db.withConnection conn: result = tryGetRecord(conn, `t`, id)
|
db.withConnection conn: result = tryGetRecord(conn, `t`, id)
|
||||||
|
|
||||||
|
proc `tryGetName`*[D: DbConnType](conn: D, id: `idType`): Option[`t`] =
|
||||||
|
result = tryGetRecord(conn, `t`, id)
|
||||||
|
|
||||||
proc `getIfExistsName`*(db: `dbType`, id: `idType`): Option[`t`] =
|
proc `getIfExistsName`*(db: `dbType`, id: `idType`): Option[`t`] =
|
||||||
db.withConnection conn:
|
db.withConnection conn:
|
||||||
try: result = some(getRecord(conn, `t`, id))
|
try: result = some(getRecord(conn, `t`, id))
|
||||||
except NotFoundError: result = none[`t`]()
|
except NotFoundError: result = none[`t`]()
|
||||||
|
|
||||||
|
proc `getIfExistsName`*[D: DbConnType](conn: D, id: `idType`): Option[`t`] =
|
||||||
|
try: result = some(getRecord(conn, `t`, id))
|
||||||
|
except NotFoundError: result = none[`t`]()
|
||||||
|
|
||||||
proc `getAllName`*(db: `dbType`, pagination = none[PaginationParams]()): PagedRecords[`t`] =
|
proc `getAllName`*(db: `dbType`, pagination = none[PaginationParams]()): PagedRecords[`t`] =
|
||||||
db.withConnection conn: result = getAllRecords(conn, `t`, pagination)
|
db.withConnection conn: result = getAllRecords(conn, `t`, pagination)
|
||||||
|
|
||||||
|
proc `getAllName`*[D: DbConnType](
|
||||||
|
conn: D,
|
||||||
|
pagination = none[PaginationParams]()): PagedRecords[`t`] =
|
||||||
|
result = getAllRecords(conn, `t`, pagination)
|
||||||
|
|
||||||
proc `findWhereName`*(
|
proc `findWhereName`*(
|
||||||
db: `dbType`,
|
db: `dbType`,
|
||||||
whereClause: string,
|
whereClause: string,
|
||||||
@@ -647,21 +662,43 @@ macro generateProcsForModels*(dbType: type, modelTypes: openarray[type]): untype
|
|||||||
db.withConnection conn:
|
db.withConnection conn:
|
||||||
result = findRecordsWhere(conn, `t`, whereClause, values, pagination)
|
result = findRecordsWhere(conn, `t`, whereClause, values, pagination)
|
||||||
|
|
||||||
|
proc `findWhereName`*[D: DbConnType](
|
||||||
|
conn: D,
|
||||||
|
whereClause: string,
|
||||||
|
values: varargs[string, dbFormat],
|
||||||
|
pagination = none[PaginationParams]()): PagedRecords[`t`] =
|
||||||
|
result = findRecordsWhere(conn, `t`, whereClause, values, pagination)
|
||||||
|
|
||||||
proc `createName`*(db: `dbType`, rec: `t`): `t` =
|
proc `createName`*(db: `dbType`, rec: `t`): `t` =
|
||||||
db.withConnection conn: result = createRecord(conn, rec)
|
db.withConnection conn: result = createRecord(conn, rec)
|
||||||
|
|
||||||
|
proc `createName`*[D: DbConnType](conn: D, rec: `t`): `t` =
|
||||||
|
result = createRecord(conn, rec)
|
||||||
|
|
||||||
proc `updateName`*(db: `dbType`, rec: `t`): bool =
|
proc `updateName`*(db: `dbType`, rec: `t`): bool =
|
||||||
db.withConnection conn: result = updateRecord(conn, rec)
|
db.withConnection conn: result = updateRecord(conn, rec)
|
||||||
|
|
||||||
|
proc `updateName`*[D: DbConnType](conn: D, rec: `t`): bool =
|
||||||
|
result = updateRecord(conn, rec)
|
||||||
|
|
||||||
proc `createOrUpdateName`*(db: `dbType`, rec: `t`): `t` =
|
proc `createOrUpdateName`*(db: `dbType`, rec: `t`): `t` =
|
||||||
db.inTransaction: result = createOrUpdateRecord(conn, rec)
|
db.inTransaction: result = createOrUpdateRecord(conn, rec)
|
||||||
|
|
||||||
|
proc `createOrUpdateName`*[D: DbConnType](conn: D, rec: `t`): `t` =
|
||||||
|
result = createOrUpdateRecord(conn, rec)
|
||||||
|
|
||||||
proc `deleteName`*(db: `dbType`, rec: `t`): bool =
|
proc `deleteName`*(db: `dbType`, rec: `t`): bool =
|
||||||
db.withConnection conn: result = deleteRecord(conn, rec)
|
db.withConnection conn: result = deleteRecord(conn, rec)
|
||||||
|
|
||||||
|
proc `deleteName`*[D: DbConnType](conn: D, rec: `t`): bool =
|
||||||
|
result = deleteRecord(conn, rec)
|
||||||
|
|
||||||
proc `deleteName`*(db: `dbType`, id: `idType`): bool =
|
proc `deleteName`*(db: `dbType`, id: `idType`): bool =
|
||||||
db.withConnection conn: result = deleteRecord(conn, `t`, id)
|
db.withConnection conn: result = deleteRecord(conn, `t`, id)
|
||||||
|
|
||||||
|
proc `deleteName`*[D: DbConnType](conn: D, id: `idType`): bool =
|
||||||
|
result = deleteRecord(conn, `t`, id)
|
||||||
|
|
||||||
macro generateLookup*(dbType: type, modelType: type, fields: seq[string]): untyped =
|
macro generateLookup*(dbType: type, modelType: type, fields: seq[string]): untyped =
|
||||||
## Create a lookup procedure for a given set of field names. For example,
|
## Create a lookup procedure for a given set of field names. For example,
|
||||||
## given the TODO database demostrated above,
|
## given the TODO database demostrated above,
|
||||||
@@ -676,39 +713,39 @@ macro generateLookup*(dbType: type, modelType: type, fields: seq[string]): untyp
|
|||||||
## owner: string, priority: int): seq[TodoItem]
|
## owner: string, priority: int): seq[TodoItem]
|
||||||
let fieldNames = fields[1].mapIt($it)
|
let fieldNames = fields[1].mapIt($it)
|
||||||
let procName = ident("find" & pluralize($modelType.getType[1]) & "By" & fieldNames.mapIt(it.capitalize).join("And"))
|
let procName = ident("find" & pluralize($modelType.getType[1]) & "By" & fieldNames.mapIt(it.capitalize).join("And"))
|
||||||
|
|
||||||
# Create proc skeleton
|
|
||||||
result = quote do:
|
|
||||||
proc `procName`*(db: `dbType`): PagedRecords[`modelType`] =
|
|
||||||
db.withConnection conn: result = findRecordsBy(conn, `modelType`)
|
|
||||||
|
|
||||||
var callParams = quote do: @[]
|
var callParams = quote do: @[]
|
||||||
|
|
||||||
# Add dynamic parameters for the proc definition and inner proc call
|
# Add dynamic parameters for the generated proc and inner proc call.
|
||||||
for n in fieldNames:
|
for n in fieldNames:
|
||||||
let paramTuple = newNimNode(nnkPar)
|
let paramTuple = newNimNode(nnkPar)
|
||||||
paramTuple.add(newColonExpr(ident("field"), newLit(identNameToDb(n))))
|
paramTuple.add(newColonExpr(ident("field"), newLit(identNameToDb(n))))
|
||||||
paramTuple.add(newColonExpr(ident("value"), ident(n)))
|
paramTuple.add(newColonExpr(ident("value"), ident(n)))
|
||||||
|
|
||||||
# Add the parameter to the outer call (the generated proc)
|
|
||||||
# result[3] is ProcDef -> [3]: FormalParams
|
|
||||||
result[3].add(newIdentDefs(ident(n), ident("string")))
|
|
||||||
|
|
||||||
# Build up the AST for the inner procedure call
|
|
||||||
callParams[1].add(paramTuple)
|
callParams[1].add(paramTuple)
|
||||||
|
|
||||||
# Add the optional pagination parameters to the generated proc definition
|
let dbProcDefAST = quote do:
|
||||||
result[3].add(newIdentDefs(
|
proc `procName`*(db: `dbType`): PagedRecords[`modelType`] =
|
||||||
|
db.withConnection conn:
|
||||||
|
result = findRecordsBy(conn, `modelType`, `callParams`, pagination)
|
||||||
|
|
||||||
|
let connProcDefAST = quote do:
|
||||||
|
proc `procName`*[D: DbConnType](conn: D): PagedRecords[`modelType`] =
|
||||||
|
result = findRecordsBy(conn, `modelType`, `callParams.copyNimTree`, pagination)
|
||||||
|
|
||||||
|
for n in fieldNames:
|
||||||
|
dbProcDefAST[3].add(newIdentDefs(ident(n), ident("string")))
|
||||||
|
connProcDefAST[3].add(newIdentDefs(ident(n), ident("string")))
|
||||||
|
|
||||||
|
dbProcDefAST[3].add(newIdentDefs(
|
||||||
ident("pagination"), newEmptyNode(),
|
ident("pagination"), newEmptyNode(),
|
||||||
quote do: none[PaginationParams]()))
|
quote do: none[PaginationParams]()))
|
||||||
|
|
||||||
# Add the call params to the inner procedure call
|
connProcDefAST[3].add(newIdentDefs(
|
||||||
# result[6][0][1][0][1] is
|
ident("pagination"), newEmptyNode(),
|
||||||
# ProcDef -> [6]: StmtList (body) -> [0]: Command ->
|
quote do: none[PaginationParams]()))
|
||||||
# [2]: StmtList (withConnection body) -> [0]: Asgn (result =) ->
|
|
||||||
# [1]: Call (inner findRecords invocation)
|
result = newStmtList()
|
||||||
result[6][0][2][0][1].add(callParams)
|
result.add dbProcDefAST
|
||||||
result[6][0][2][0][1].add(quote do: pagination)
|
result.add connProcDefAST
|
||||||
|
|
||||||
macro generateProcsForFieldLookups*(dbType: type, modelsAndFields: openarray[tuple[t: type, fields: seq[string]]]): untyped =
|
macro generateProcsForFieldLookups*(dbType: type, modelsAndFields: openarray[tuple[t: type, fields: seq[string]]]): untyped =
|
||||||
result = newStmtList()
|
result = newStmtList()
|
||||||
@@ -718,32 +755,38 @@ macro generateProcsForFieldLookups*(dbType: type, modelsAndFields: openarray[tup
|
|||||||
let fieldNames = i[1][1][1].mapIt($it)
|
let fieldNames = i[1][1][1].mapIt($it)
|
||||||
|
|
||||||
let procName = ident("find" & $modelType & "sBy" & fieldNames.mapIt(it.capitalize).join("And"))
|
let procName = ident("find" & $modelType & "sBy" & fieldNames.mapIt(it.capitalize).join("And"))
|
||||||
|
|
||||||
# Create proc skeleton
|
|
||||||
let procDefAST = quote do:
|
|
||||||
proc `procName`*(db: `dbType`): PagedRecords[`modelType`] =
|
|
||||||
db.withConnection conn: result = findRecordsBy(conn, `modelType`)
|
|
||||||
|
|
||||||
var callParams = quote do: @[]
|
var callParams = quote do: @[]
|
||||||
|
|
||||||
# Add dynamic parameters for the proc definition and inner proc call
|
# Add dynamic parameters for the generated proc and inner proc call.
|
||||||
for n in fieldNames:
|
for n in fieldNames:
|
||||||
let paramTuple = newNimNode(nnkPar)
|
let paramTuple = newNimNode(nnkPar)
|
||||||
paramTuple.add(newColonExpr(ident("field"), newLit(identNameToDb(n))))
|
paramTuple.add(newColonExpr(ident("field"), newLit(identNameToDb(n))))
|
||||||
paramTuple.add(newColonExpr(ident("value"), ident(n)))
|
paramTuple.add(newColonExpr(ident("value"), ident(n)))
|
||||||
|
|
||||||
procDefAST[3].add(newIdentDefs(ident(n), ident("string")))
|
|
||||||
callParams[1].add(paramTuple)
|
callParams[1].add(paramTuple)
|
||||||
|
|
||||||
# Add the optional pagination parameters to the generated proc definition
|
let dbProcDefAST = quote do:
|
||||||
procDefAST[3].add(newIdentDefs(
|
proc `procName`*(db: `dbType`): PagedRecords[`modelType`] =
|
||||||
|
db.withConnection conn:
|
||||||
|
result = findRecordsBy(conn, `modelType`, `callParams`, pagination)
|
||||||
|
|
||||||
|
let connProcDefAST = quote do:
|
||||||
|
proc `procName`*[D: DbConnType](conn: D): PagedRecords[`modelType`] =
|
||||||
|
result = findRecordsBy(conn, `modelType`, `callParams.copyNimTree`, pagination)
|
||||||
|
|
||||||
|
for n in fieldNames:
|
||||||
|
dbProcDefAST[3].add(newIdentDefs(ident(n), ident("string")))
|
||||||
|
connProcDefAST[3].add(newIdentDefs(ident(n), ident("string")))
|
||||||
|
|
||||||
|
dbProcDefAST[3].add(newIdentDefs(
|
||||||
ident("pagination"), newEmptyNode(),
|
ident("pagination"), newEmptyNode(),
|
||||||
quote do: none[PaginationParams]()))
|
quote do: none[PaginationParams]()))
|
||||||
|
|
||||||
procDefAST[6][0][1][0][1].add(callParams)
|
connProcDefAST[3].add(newIdentDefs(
|
||||||
procDefAST[6][0][1][0][1].add(quote do: pagination)
|
ident("pagination"), newEmptyNode(),
|
||||||
|
quote do: none[PaginationParams]()))
|
||||||
|
|
||||||
result.add procDefAST
|
result.add dbProcDefAST
|
||||||
|
result.add connProcDefAST
|
||||||
|
|
||||||
macro generateJoinTableProcs*(
|
macro generateJoinTableProcs*(
|
||||||
dbType, model1Type, model2Type: type,
|
dbType, model1Type, model2Type: type,
|
||||||
@@ -791,6 +834,18 @@ macro generateJoinTableProcs*(
|
|||||||
id,
|
id,
|
||||||
pagination)
|
pagination)
|
||||||
|
|
||||||
|
proc `getModel1Name`*[D: DbConnType](
|
||||||
|
conn: D,
|
||||||
|
id: `id2Type`,
|
||||||
|
pagination = none[PaginationParams]()): PagedRecords[`model1Type`] =
|
||||||
|
result = findViaJoinTable(
|
||||||
|
conn,
|
||||||
|
`joinTableNameNode`,
|
||||||
|
`model1Type`,
|
||||||
|
`model2Type`,
|
||||||
|
id,
|
||||||
|
pagination)
|
||||||
|
|
||||||
proc `getModel1Name`*(
|
proc `getModel1Name`*(
|
||||||
db: `dbType`,
|
db: `dbType`,
|
||||||
rec: `model2Type`,
|
rec: `model2Type`,
|
||||||
@@ -803,10 +858,21 @@ macro generateJoinTableProcs*(
|
|||||||
rec,
|
rec,
|
||||||
pagination)
|
pagination)
|
||||||
|
|
||||||
|
proc `getModel1Name`*[D: DbConnType](
|
||||||
|
conn: D,
|
||||||
|
rec: `model2Type`,
|
||||||
|
pagination = none[PaginationParams]()): PagedRecords[`model1Type`] =
|
||||||
|
result = findViaJoinTable(
|
||||||
|
conn,
|
||||||
|
`joinTableNameNode`,
|
||||||
|
`model1Type`,
|
||||||
|
rec,
|
||||||
|
pagination)
|
||||||
|
|
||||||
proc `getModel2Name`*(
|
proc `getModel2Name`*(
|
||||||
db: `dbType`,
|
db: `dbType`,
|
||||||
id: `id1Type`,
|
id: `id1Type`,
|
||||||
pagination = none[PaginationParams]()): Pagedrecords[`model2Type`] =
|
pagination = none[PaginationParams]()): PagedRecords[`model2Type`] =
|
||||||
db.withConnection conn:
|
db.withConnection conn:
|
||||||
result = findViaJoinTable(
|
result = findViaJoinTable(
|
||||||
conn,
|
conn,
|
||||||
@@ -816,10 +882,22 @@ macro generateJoinTableProcs*(
|
|||||||
id,
|
id,
|
||||||
pagination)
|
pagination)
|
||||||
|
|
||||||
|
proc `getModel2Name`*[D: DbConnType](
|
||||||
|
conn: D,
|
||||||
|
id: `id1Type`,
|
||||||
|
pagination = none[PaginationParams]()): PagedRecords[`model2Type`] =
|
||||||
|
result = findViaJoinTable(
|
||||||
|
conn,
|
||||||
|
`joinTableNameNode`,
|
||||||
|
`model2Type`,
|
||||||
|
`model1Type`,
|
||||||
|
id,
|
||||||
|
pagination)
|
||||||
|
|
||||||
proc `getModel2Name`*(
|
proc `getModel2Name`*(
|
||||||
db: `dbType`,
|
db: `dbType`,
|
||||||
rec: `model1Type`,
|
rec: `model1Type`,
|
||||||
pagination = none[PaginationParams]()): Pagedrecords[`model2Type`] =
|
pagination = none[PaginationParams]()): PagedRecords[`model2Type`] =
|
||||||
db.withConnection conn:
|
db.withConnection conn:
|
||||||
result = findViaJoinTable(
|
result = findViaJoinTable(
|
||||||
conn,
|
conn,
|
||||||
@@ -828,6 +906,17 @@ macro generateJoinTableProcs*(
|
|||||||
rec,
|
rec,
|
||||||
pagination)
|
pagination)
|
||||||
|
|
||||||
|
proc `getModel2Name`*[D: DbConnType](
|
||||||
|
conn: D,
|
||||||
|
rec: `model1Type`,
|
||||||
|
pagination = none[PaginationParams]()): PagedRecords[`model2Type`] =
|
||||||
|
result = findViaJoinTable(
|
||||||
|
conn,
|
||||||
|
`joinTableNameNode`,
|
||||||
|
`model2Type`,
|
||||||
|
rec,
|
||||||
|
pagination)
|
||||||
|
|
||||||
proc associate*(
|
proc associate*(
|
||||||
db: `dbType`,
|
db: `dbType`,
|
||||||
rec1: `model1Type`,
|
rec1: `model1Type`,
|
||||||
@@ -835,6 +924,12 @@ macro generateJoinTableProcs*(
|
|||||||
db.withConnection conn:
|
db.withConnection conn:
|
||||||
associate(conn, `joinTableNameNode`, rec1, rec2)
|
associate(conn, `joinTableNameNode`, rec1, rec2)
|
||||||
|
|
||||||
|
proc associate*[D: DbConnType](
|
||||||
|
conn: D,
|
||||||
|
rec1: `model1Type`,
|
||||||
|
rec2: `model2Type`): void =
|
||||||
|
associate(conn, `joinTableNameNode`, rec1, rec2)
|
||||||
|
|
||||||
proc associate*(
|
proc associate*(
|
||||||
db: `dbType`,
|
db: `dbType`,
|
||||||
rec2: `model2Type`,
|
rec2: `model2Type`,
|
||||||
@@ -842,6 +937,12 @@ macro generateJoinTableProcs*(
|
|||||||
db.withConnection conn:
|
db.withConnection conn:
|
||||||
associate(conn, `joinTableNameNode`, rec1, rec2)
|
associate(conn, `joinTableNameNode`, rec1, rec2)
|
||||||
|
|
||||||
|
proc associate*[D: DbConnType](
|
||||||
|
conn: D,
|
||||||
|
rec2: `model2Type`,
|
||||||
|
rec1: `model1Type`): void =
|
||||||
|
associate(conn, `joinTableNameNode`, rec1, rec2)
|
||||||
|
|
||||||
template inTransaction*(db, body: untyped) =
|
template inTransaction*(db, body: untyped) =
|
||||||
db.withConnection conn:
|
db.withConnection conn:
|
||||||
conn.exec(sql"BEGIN TRANSACTION")
|
conn.exec(sql"BEGIN TRANSACTION")
|
||||||
|
|||||||
Reference in New Issue
Block a user