diff --git a/src/fiber_orm.nim b/src/fiber_orm.nim index b0d524c..62fbae7 100644 --- a/src/fiber_orm.nim +++ b/src/fiber_orm.nim @@ -628,17 +628,32 @@ macro generateProcsForModels*(dbType: type, modelTypes: openarray[type]): untype proc `getName`*(db: `dbType`, id: `idType`): `t` = db.withConnection conn: result = getRecord(conn, `t`, id) + proc `getName`*[D: DbConnType](conn: D, id: `idType`): `t` = + result = getRecord(conn, `t`, id) + proc `tryGetName`*(db: `dbType`, id: `idType`): Option[`t`] = db.withConnection conn: result = tryGetRecord(conn, `t`, id) + proc `tryGetName`*[D: DbConnType](conn: D, id: `idType`): Option[`t`] = + result = tryGetRecord(conn, `t`, id) + proc `getIfExistsName`*(db: `dbType`, id: `idType`): Option[`t`] = db.withConnection conn: try: result = some(getRecord(conn, `t`, id)) except NotFoundError: result = none[`t`]() + proc `getIfExistsName`*[D: DbConnType](conn: D, id: `idType`): Option[`t`] = + try: result = some(getRecord(conn, `t`, id)) + except NotFoundError: result = none[`t`]() + proc `getAllName`*(db: `dbType`, pagination = none[PaginationParams]()): PagedRecords[`t`] = db.withConnection conn: result = getAllRecords(conn, `t`, pagination) + proc `getAllName`*[D: DbConnType]( + conn: D, + pagination = none[PaginationParams]()): PagedRecords[`t`] = + result = getAllRecords(conn, `t`, pagination) + proc `findWhereName`*( db: `dbType`, whereClause: string, @@ -647,21 +662,43 @@ macro generateProcsForModels*(dbType: type, modelTypes: openarray[type]): untype db.withConnection conn: result = findRecordsWhere(conn, `t`, whereClause, values, pagination) + proc `findWhereName`*[D: DbConnType]( + conn: D, + whereClause: string, + values: varargs[string, dbFormat], + pagination = none[PaginationParams]()): PagedRecords[`t`] = + result = findRecordsWhere(conn, `t`, whereClause, values, pagination) + proc `createName`*(db: `dbType`, rec: `t`): `t` = db.withConnection conn: result = createRecord(conn, rec) + proc `createName`*[D: DbConnType](conn: D, rec: `t`): `t` = + result = createRecord(conn, rec) + proc `updateName`*(db: `dbType`, rec: `t`): bool = db.withConnection conn: result = updateRecord(conn, rec) + proc `updateName`*[D: DbConnType](conn: D, rec: `t`): bool = + result = updateRecord(conn, rec) + proc `createOrUpdateName`*(db: `dbType`, rec: `t`): `t` = db.inTransaction: result = createOrUpdateRecord(conn, rec) + proc `createOrUpdateName`*[D: DbConnType](conn: D, rec: `t`): `t` = + result = createOrUpdateRecord(conn, rec) + proc `deleteName`*(db: `dbType`, rec: `t`): bool = db.withConnection conn: result = deleteRecord(conn, rec) + proc `deleteName`*[D: DbConnType](conn: D, rec: `t`): bool = + result = deleteRecord(conn, rec) + proc `deleteName`*(db: `dbType`, id: `idType`): bool = db.withConnection conn: result = deleteRecord(conn, `t`, id) + proc `deleteName`*[D: DbConnType](conn: D, id: `idType`): bool = + result = deleteRecord(conn, `t`, id) + macro generateLookup*(dbType: type, modelType: type, fields: seq[string]): untyped = ## Create a lookup procedure for a given set of field names. For example, ## given the TODO database demostrated above, @@ -676,39 +713,39 @@ macro generateLookup*(dbType: type, modelType: type, fields: seq[string]): untyp ## owner: string, priority: int): seq[TodoItem] let fieldNames = fields[1].mapIt($it) let procName = ident("find" & pluralize($modelType.getType[1]) & "By" & fieldNames.mapIt(it.capitalize).join("And")) - - # Create proc skeleton - result = quote do: - proc `procName`*(db: `dbType`): PagedRecords[`modelType`] = - db.withConnection conn: result = findRecordsBy(conn, `modelType`) - var callParams = quote do: @[] - # Add dynamic parameters for the proc definition and inner proc call + # Add dynamic parameters for the generated proc and inner proc call. for n in fieldNames: let paramTuple = newNimNode(nnkPar) paramTuple.add(newColonExpr(ident("field"), newLit(identNameToDb(n)))) paramTuple.add(newColonExpr(ident("value"), ident(n))) - - # Add the parameter to the outer call (the generated proc) - # result[3] is ProcDef -> [3]: FormalParams - result[3].add(newIdentDefs(ident(n), ident("string"))) - - # Build up the AST for the inner procedure call callParams[1].add(paramTuple) - # Add the optional pagination parameters to the generated proc definition - result[3].add(newIdentDefs( + let dbProcDefAST = quote do: + proc `procName`*(db: `dbType`): PagedRecords[`modelType`] = + db.withConnection conn: + result = findRecordsBy(conn, `modelType`, `callParams`, pagination) + + let connProcDefAST = quote do: + proc `procName`*[D: DbConnType](conn: D): PagedRecords[`modelType`] = + result = findRecordsBy(conn, `modelType`, `callParams.copyNimTree`, pagination) + + for n in fieldNames: + dbProcDefAST[3].add(newIdentDefs(ident(n), ident("string"))) + connProcDefAST[3].add(newIdentDefs(ident(n), ident("string"))) + + dbProcDefAST[3].add(newIdentDefs( ident("pagination"), newEmptyNode(), quote do: none[PaginationParams]())) - # Add the call params to the inner procedure call - # result[6][0][1][0][1] is - # ProcDef -> [6]: StmtList (body) -> [0]: Command -> - # [2]: StmtList (withConnection body) -> [0]: Asgn (result =) -> - # [1]: Call (inner findRecords invocation) - result[6][0][2][0][1].add(callParams) - result[6][0][2][0][1].add(quote do: pagination) + connProcDefAST[3].add(newIdentDefs( + ident("pagination"), newEmptyNode(), + quote do: none[PaginationParams]())) + + result = newStmtList() + result.add dbProcDefAST + result.add connProcDefAST macro generateProcsForFieldLookups*(dbType: type, modelsAndFields: openarray[tuple[t: type, fields: seq[string]]]): untyped = result = newStmtList() @@ -718,32 +755,38 @@ macro generateProcsForFieldLookups*(dbType: type, modelsAndFields: openarray[tup let fieldNames = i[1][1][1].mapIt($it) let procName = ident("find" & $modelType & "sBy" & fieldNames.mapIt(it.capitalize).join("And")) - - # Create proc skeleton - let procDefAST = quote do: - proc `procName`*(db: `dbType`): PagedRecords[`modelType`] = - db.withConnection conn: result = findRecordsBy(conn, `modelType`) - var callParams = quote do: @[] - # Add dynamic parameters for the proc definition and inner proc call + # Add dynamic parameters for the generated proc and inner proc call. for n in fieldNames: let paramTuple = newNimNode(nnkPar) paramTuple.add(newColonExpr(ident("field"), newLit(identNameToDb(n)))) paramTuple.add(newColonExpr(ident("value"), ident(n))) - - procDefAST[3].add(newIdentDefs(ident(n), ident("string"))) callParams[1].add(paramTuple) - # Add the optional pagination parameters to the generated proc definition - procDefAST[3].add(newIdentDefs( + let dbProcDefAST = quote do: + proc `procName`*(db: `dbType`): PagedRecords[`modelType`] = + db.withConnection conn: + result = findRecordsBy(conn, `modelType`, `callParams`, pagination) + + let connProcDefAST = quote do: + proc `procName`*[D: DbConnType](conn: D): PagedRecords[`modelType`] = + result = findRecordsBy(conn, `modelType`, `callParams.copyNimTree`, pagination) + + for n in fieldNames: + dbProcDefAST[3].add(newIdentDefs(ident(n), ident("string"))) + connProcDefAST[3].add(newIdentDefs(ident(n), ident("string"))) + + dbProcDefAST[3].add(newIdentDefs( ident("pagination"), newEmptyNode(), quote do: none[PaginationParams]())) - procDefAST[6][0][1][0][1].add(callParams) - procDefAST[6][0][1][0][1].add(quote do: pagination) + connProcDefAST[3].add(newIdentDefs( + ident("pagination"), newEmptyNode(), + quote do: none[PaginationParams]())) - result.add procDefAST + result.add dbProcDefAST + result.add connProcDefAST macro generateJoinTableProcs*( dbType, model1Type, model2Type: type, @@ -791,6 +834,18 @@ macro generateJoinTableProcs*( id, pagination) + proc `getModel1Name`*[D: DbConnType]( + conn: D, + id: `id2Type`, + pagination = none[PaginationParams]()): PagedRecords[`model1Type`] = + result = findViaJoinTable( + conn, + `joinTableNameNode`, + `model1Type`, + `model2Type`, + id, + pagination) + proc `getModel1Name`*( db: `dbType`, rec: `model2Type`, @@ -803,10 +858,21 @@ macro generateJoinTableProcs*( rec, pagination) + proc `getModel1Name`*[D: DbConnType]( + conn: D, + rec: `model2Type`, + pagination = none[PaginationParams]()): PagedRecords[`model1Type`] = + result = findViaJoinTable( + conn, + `joinTableNameNode`, + `model1Type`, + rec, + pagination) + proc `getModel2Name`*( db: `dbType`, id: `id1Type`, - pagination = none[PaginationParams]()): Pagedrecords[`model2Type`] = + pagination = none[PaginationParams]()): PagedRecords[`model2Type`] = db.withConnection conn: result = findViaJoinTable( conn, @@ -816,10 +882,22 @@ macro generateJoinTableProcs*( id, pagination) + proc `getModel2Name`*[D: DbConnType]( + conn: D, + id: `id1Type`, + pagination = none[PaginationParams]()): PagedRecords[`model2Type`] = + result = findViaJoinTable( + conn, + `joinTableNameNode`, + `model2Type`, + `model1Type`, + id, + pagination) + proc `getModel2Name`*( db: `dbType`, rec: `model1Type`, - pagination = none[PaginationParams]()): Pagedrecords[`model2Type`] = + pagination = none[PaginationParams]()): PagedRecords[`model2Type`] = db.withConnection conn: result = findViaJoinTable( conn, @@ -828,6 +906,17 @@ macro generateJoinTableProcs*( rec, pagination) + proc `getModel2Name`*[D: DbConnType]( + conn: D, + rec: `model1Type`, + pagination = none[PaginationParams]()): PagedRecords[`model2Type`] = + result = findViaJoinTable( + conn, + `joinTableNameNode`, + `model2Type`, + rec, + pagination) + proc associate*( db: `dbType`, rec1: `model1Type`, @@ -835,6 +924,12 @@ macro generateJoinTableProcs*( db.withConnection conn: associate(conn, `joinTableNameNode`, rec1, rec2) + proc associate*[D: DbConnType]( + conn: D, + rec1: `model1Type`, + rec2: `model2Type`): void = + associate(conn, `joinTableNameNode`, rec1, rec2) + proc associate*( db: `dbType`, rec2: `model2Type`, @@ -842,6 +937,12 @@ macro generateJoinTableProcs*( db.withConnection conn: associate(conn, `joinTableNameNode`, rec1, rec2) + proc associate*[D: DbConnType]( + conn: D, + rec2: `model2Type`, + rec1: `model1Type`): void = + associate(conn, `joinTableNameNode`, rec1, rec2) + template inTransaction*(db, body: untyped) = db.withConnection conn: conn.exec(sql"BEGIN TRANSACTION")