Add connection overloads for generated ORM procs
The generated ORM helpers currently only accept the database wrapper
type as their first argument. That works well for the common case, but
it becomes misleading inside inTransaction blocks because the generated
proc will call withConnection again and may therefore use a different
connection than the one that is participating in the transaction.
Add DbConnType-constrained overloads for the generated model CRUD/query
procs, generated lookups, and generated join-table helpers. This lets
callers explicitly use the transaction connection while keeping the
existing dbType-based API intact for non-transactional call sites.
This makes the intended transactional usage straightforward:
db.inTransaction:
var userRecord = conn.getUser("userId1")
userRecord.visitCount += 1
discard conn.updateUser(userRecord)
AI-Assisted: yes
AI-Tool: OpenAI Codes / gpt-5.4 xhigh
This commit is contained in:
@@ -628,17 +628,32 @@ macro generateProcsForModels*(dbType: type, modelTypes: openarray[type]): untype
|
||||
proc `getName`*(db: `dbType`, id: `idType`): `t` =
|
||||
db.withConnection conn: result = getRecord(conn, `t`, id)
|
||||
|
||||
proc `getName`*[D: DbConnType](conn: D, id: `idType`): `t` =
|
||||
result = getRecord(conn, `t`, id)
|
||||
|
||||
proc `tryGetName`*(db: `dbType`, id: `idType`): Option[`t`] =
|
||||
db.withConnection conn: result = tryGetRecord(conn, `t`, id)
|
||||
|
||||
proc `tryGetName`*[D: DbConnType](conn: D, id: `idType`): Option[`t`] =
|
||||
result = tryGetRecord(conn, `t`, id)
|
||||
|
||||
proc `getIfExistsName`*(db: `dbType`, id: `idType`): Option[`t`] =
|
||||
db.withConnection conn:
|
||||
try: result = some(getRecord(conn, `t`, id))
|
||||
except NotFoundError: result = none[`t`]()
|
||||
|
||||
proc `getIfExistsName`*[D: DbConnType](conn: D, id: `idType`): Option[`t`] =
|
||||
try: result = some(getRecord(conn, `t`, id))
|
||||
except NotFoundError: result = none[`t`]()
|
||||
|
||||
proc `getAllName`*(db: `dbType`, pagination = none[PaginationParams]()): PagedRecords[`t`] =
|
||||
db.withConnection conn: result = getAllRecords(conn, `t`, pagination)
|
||||
|
||||
proc `getAllName`*[D: DbConnType](
|
||||
conn: D,
|
||||
pagination = none[PaginationParams]()): PagedRecords[`t`] =
|
||||
result = getAllRecords(conn, `t`, pagination)
|
||||
|
||||
proc `findWhereName`*(
|
||||
db: `dbType`,
|
||||
whereClause: string,
|
||||
@@ -647,21 +662,43 @@ macro generateProcsForModels*(dbType: type, modelTypes: openarray[type]): untype
|
||||
db.withConnection conn:
|
||||
result = findRecordsWhere(conn, `t`, whereClause, values, pagination)
|
||||
|
||||
proc `findWhereName`*[D: DbConnType](
|
||||
conn: D,
|
||||
whereClause: string,
|
||||
values: varargs[string, dbFormat],
|
||||
pagination = none[PaginationParams]()): PagedRecords[`t`] =
|
||||
result = findRecordsWhere(conn, `t`, whereClause, values, pagination)
|
||||
|
||||
proc `createName`*(db: `dbType`, rec: `t`): `t` =
|
||||
db.withConnection conn: result = createRecord(conn, rec)
|
||||
|
||||
proc `createName`*[D: DbConnType](conn: D, rec: `t`): `t` =
|
||||
result = createRecord(conn, rec)
|
||||
|
||||
proc `updateName`*(db: `dbType`, rec: `t`): bool =
|
||||
db.withConnection conn: result = updateRecord(conn, rec)
|
||||
|
||||
proc `updateName`*[D: DbConnType](conn: D, rec: `t`): bool =
|
||||
result = updateRecord(conn, rec)
|
||||
|
||||
proc `createOrUpdateName`*(db: `dbType`, rec: `t`): `t` =
|
||||
db.inTransaction: result = createOrUpdateRecord(conn, rec)
|
||||
|
||||
proc `createOrUpdateName`*[D: DbConnType](conn: D, rec: `t`): `t` =
|
||||
result = createOrUpdateRecord(conn, rec)
|
||||
|
||||
proc `deleteName`*(db: `dbType`, rec: `t`): bool =
|
||||
db.withConnection conn: result = deleteRecord(conn, rec)
|
||||
|
||||
proc `deleteName`*[D: DbConnType](conn: D, rec: `t`): bool =
|
||||
result = deleteRecord(conn, rec)
|
||||
|
||||
proc `deleteName`*(db: `dbType`, id: `idType`): bool =
|
||||
db.withConnection conn: result = deleteRecord(conn, `t`, id)
|
||||
|
||||
proc `deleteName`*[D: DbConnType](conn: D, id: `idType`): bool =
|
||||
result = deleteRecord(conn, `t`, id)
|
||||
|
||||
macro generateLookup*(dbType: type, modelType: type, fields: seq[string]): untyped =
|
||||
## Create a lookup procedure for a given set of field names. For example,
|
||||
## given the TODO database demostrated above,
|
||||
@@ -676,39 +713,39 @@ macro generateLookup*(dbType: type, modelType: type, fields: seq[string]): untyp
|
||||
## owner: string, priority: int): seq[TodoItem]
|
||||
let fieldNames = fields[1].mapIt($it)
|
||||
let procName = ident("find" & pluralize($modelType.getType[1]) & "By" & fieldNames.mapIt(it.capitalize).join("And"))
|
||||
|
||||
# Create proc skeleton
|
||||
result = quote do:
|
||||
proc `procName`*(db: `dbType`): PagedRecords[`modelType`] =
|
||||
db.withConnection conn: result = findRecordsBy(conn, `modelType`)
|
||||
|
||||
var callParams = quote do: @[]
|
||||
|
||||
# Add dynamic parameters for the proc definition and inner proc call
|
||||
# Add dynamic parameters for the generated proc and inner proc call.
|
||||
for n in fieldNames:
|
||||
let paramTuple = newNimNode(nnkPar)
|
||||
paramTuple.add(newColonExpr(ident("field"), newLit(identNameToDb(n))))
|
||||
paramTuple.add(newColonExpr(ident("value"), ident(n)))
|
||||
|
||||
# Add the parameter to the outer call (the generated proc)
|
||||
# result[3] is ProcDef -> [3]: FormalParams
|
||||
result[3].add(newIdentDefs(ident(n), ident("string")))
|
||||
|
||||
# Build up the AST for the inner procedure call
|
||||
callParams[1].add(paramTuple)
|
||||
|
||||
# Add the optional pagination parameters to the generated proc definition
|
||||
result[3].add(newIdentDefs(
|
||||
let dbProcDefAST = quote do:
|
||||
proc `procName`*(db: `dbType`): PagedRecords[`modelType`] =
|
||||
db.withConnection conn:
|
||||
result = findRecordsBy(conn, `modelType`, `callParams`, pagination)
|
||||
|
||||
let connProcDefAST = quote do:
|
||||
proc `procName`*[D: DbConnType](conn: D): PagedRecords[`modelType`] =
|
||||
result = findRecordsBy(conn, `modelType`, `callParams.copyNimTree`, pagination)
|
||||
|
||||
for n in fieldNames:
|
||||
dbProcDefAST[3].add(newIdentDefs(ident(n), ident("string")))
|
||||
connProcDefAST[3].add(newIdentDefs(ident(n), ident("string")))
|
||||
|
||||
dbProcDefAST[3].add(newIdentDefs(
|
||||
ident("pagination"), newEmptyNode(),
|
||||
quote do: none[PaginationParams]()))
|
||||
|
||||
# Add the call params to the inner procedure call
|
||||
# result[6][0][1][0][1] is
|
||||
# ProcDef -> [6]: StmtList (body) -> [0]: Command ->
|
||||
# [2]: StmtList (withConnection body) -> [0]: Asgn (result =) ->
|
||||
# [1]: Call (inner findRecords invocation)
|
||||
result[6][0][2][0][1].add(callParams)
|
||||
result[6][0][2][0][1].add(quote do: pagination)
|
||||
connProcDefAST[3].add(newIdentDefs(
|
||||
ident("pagination"), newEmptyNode(),
|
||||
quote do: none[PaginationParams]()))
|
||||
|
||||
result = newStmtList()
|
||||
result.add dbProcDefAST
|
||||
result.add connProcDefAST
|
||||
|
||||
macro generateProcsForFieldLookups*(dbType: type, modelsAndFields: openarray[tuple[t: type, fields: seq[string]]]): untyped =
|
||||
result = newStmtList()
|
||||
@@ -718,32 +755,38 @@ macro generateProcsForFieldLookups*(dbType: type, modelsAndFields: openarray[tup
|
||||
let fieldNames = i[1][1][1].mapIt($it)
|
||||
|
||||
let procName = ident("find" & $modelType & "sBy" & fieldNames.mapIt(it.capitalize).join("And"))
|
||||
|
||||
# Create proc skeleton
|
||||
let procDefAST = quote do:
|
||||
proc `procName`*(db: `dbType`): PagedRecords[`modelType`] =
|
||||
db.withConnection conn: result = findRecordsBy(conn, `modelType`)
|
||||
|
||||
var callParams = quote do: @[]
|
||||
|
||||
# Add dynamic parameters for the proc definition and inner proc call
|
||||
# Add dynamic parameters for the generated proc and inner proc call.
|
||||
for n in fieldNames:
|
||||
let paramTuple = newNimNode(nnkPar)
|
||||
paramTuple.add(newColonExpr(ident("field"), newLit(identNameToDb(n))))
|
||||
paramTuple.add(newColonExpr(ident("value"), ident(n)))
|
||||
|
||||
procDefAST[3].add(newIdentDefs(ident(n), ident("string")))
|
||||
callParams[1].add(paramTuple)
|
||||
|
||||
# Add the optional pagination parameters to the generated proc definition
|
||||
procDefAST[3].add(newIdentDefs(
|
||||
let dbProcDefAST = quote do:
|
||||
proc `procName`*(db: `dbType`): PagedRecords[`modelType`] =
|
||||
db.withConnection conn:
|
||||
result = findRecordsBy(conn, `modelType`, `callParams`, pagination)
|
||||
|
||||
let connProcDefAST = quote do:
|
||||
proc `procName`*[D: DbConnType](conn: D): PagedRecords[`modelType`] =
|
||||
result = findRecordsBy(conn, `modelType`, `callParams.copyNimTree`, pagination)
|
||||
|
||||
for n in fieldNames:
|
||||
dbProcDefAST[3].add(newIdentDefs(ident(n), ident("string")))
|
||||
connProcDefAST[3].add(newIdentDefs(ident(n), ident("string")))
|
||||
|
||||
dbProcDefAST[3].add(newIdentDefs(
|
||||
ident("pagination"), newEmptyNode(),
|
||||
quote do: none[PaginationParams]()))
|
||||
|
||||
procDefAST[6][0][1][0][1].add(callParams)
|
||||
procDefAST[6][0][1][0][1].add(quote do: pagination)
|
||||
connProcDefAST[3].add(newIdentDefs(
|
||||
ident("pagination"), newEmptyNode(),
|
||||
quote do: none[PaginationParams]()))
|
||||
|
||||
result.add procDefAST
|
||||
result.add dbProcDefAST
|
||||
result.add connProcDefAST
|
||||
|
||||
macro generateJoinTableProcs*(
|
||||
dbType, model1Type, model2Type: type,
|
||||
@@ -791,6 +834,18 @@ macro generateJoinTableProcs*(
|
||||
id,
|
||||
pagination)
|
||||
|
||||
proc `getModel1Name`*[D: DbConnType](
|
||||
conn: D,
|
||||
id: `id2Type`,
|
||||
pagination = none[PaginationParams]()): PagedRecords[`model1Type`] =
|
||||
result = findViaJoinTable(
|
||||
conn,
|
||||
`joinTableNameNode`,
|
||||
`model1Type`,
|
||||
`model2Type`,
|
||||
id,
|
||||
pagination)
|
||||
|
||||
proc `getModel1Name`*(
|
||||
db: `dbType`,
|
||||
rec: `model2Type`,
|
||||
@@ -803,10 +858,21 @@ macro generateJoinTableProcs*(
|
||||
rec,
|
||||
pagination)
|
||||
|
||||
proc `getModel1Name`*[D: DbConnType](
|
||||
conn: D,
|
||||
rec: `model2Type`,
|
||||
pagination = none[PaginationParams]()): PagedRecords[`model1Type`] =
|
||||
result = findViaJoinTable(
|
||||
conn,
|
||||
`joinTableNameNode`,
|
||||
`model1Type`,
|
||||
rec,
|
||||
pagination)
|
||||
|
||||
proc `getModel2Name`*(
|
||||
db: `dbType`,
|
||||
id: `id1Type`,
|
||||
pagination = none[PaginationParams]()): Pagedrecords[`model2Type`] =
|
||||
pagination = none[PaginationParams]()): PagedRecords[`model2Type`] =
|
||||
db.withConnection conn:
|
||||
result = findViaJoinTable(
|
||||
conn,
|
||||
@@ -816,10 +882,22 @@ macro generateJoinTableProcs*(
|
||||
id,
|
||||
pagination)
|
||||
|
||||
proc `getModel2Name`*[D: DbConnType](
|
||||
conn: D,
|
||||
id: `id1Type`,
|
||||
pagination = none[PaginationParams]()): PagedRecords[`model2Type`] =
|
||||
result = findViaJoinTable(
|
||||
conn,
|
||||
`joinTableNameNode`,
|
||||
`model2Type`,
|
||||
`model1Type`,
|
||||
id,
|
||||
pagination)
|
||||
|
||||
proc `getModel2Name`*(
|
||||
db: `dbType`,
|
||||
rec: `model1Type`,
|
||||
pagination = none[PaginationParams]()): Pagedrecords[`model2Type`] =
|
||||
pagination = none[PaginationParams]()): PagedRecords[`model2Type`] =
|
||||
db.withConnection conn:
|
||||
result = findViaJoinTable(
|
||||
conn,
|
||||
@@ -828,6 +906,17 @@ macro generateJoinTableProcs*(
|
||||
rec,
|
||||
pagination)
|
||||
|
||||
proc `getModel2Name`*[D: DbConnType](
|
||||
conn: D,
|
||||
rec: `model1Type`,
|
||||
pagination = none[PaginationParams]()): PagedRecords[`model2Type`] =
|
||||
result = findViaJoinTable(
|
||||
conn,
|
||||
`joinTableNameNode`,
|
||||
`model2Type`,
|
||||
rec,
|
||||
pagination)
|
||||
|
||||
proc associate*(
|
||||
db: `dbType`,
|
||||
rec1: `model1Type`,
|
||||
@@ -835,6 +924,12 @@ macro generateJoinTableProcs*(
|
||||
db.withConnection conn:
|
||||
associate(conn, `joinTableNameNode`, rec1, rec2)
|
||||
|
||||
proc associate*[D: DbConnType](
|
||||
conn: D,
|
||||
rec1: `model1Type`,
|
||||
rec2: `model2Type`): void =
|
||||
associate(conn, `joinTableNameNode`, rec1, rec2)
|
||||
|
||||
proc associate*(
|
||||
db: `dbType`,
|
||||
rec2: `model2Type`,
|
||||
@@ -842,6 +937,12 @@ macro generateJoinTableProcs*(
|
||||
db.withConnection conn:
|
||||
associate(conn, `joinTableNameNode`, rec1, rec2)
|
||||
|
||||
proc associate*[D: DbConnType](
|
||||
conn: D,
|
||||
rec2: `model2Type`,
|
||||
rec1: `model1Type`): void =
|
||||
associate(conn, `joinTableNameNode`, rec1, rec2)
|
||||
|
||||
template inTransaction*(db, body: untyped) =
|
||||
db.withConnection conn:
|
||||
conn.exec(sql"BEGIN TRANSACTION")
|
||||
|
||||
Reference in New Issue
Block a user