diff --git a/fiber_orm.nimble b/fiber_orm.nimble
index 93b45f6..3b5359b 100644
--- a/fiber_orm.nimble
+++ b/fiber_orm.nimble
@@ -1,6 +1,6 @@
 # Package

-version = "2.0.0"
+version = "2.1.0"
 author = "Jonathan Bernard"
 description = "Lightweight Postgres ORM for Nim."
 license = "GPL-3.0"
diff --git a/src/fiber_orm.nim b/src/fiber_orm.nim
index db2ee5a..8fbc560 100644
--- a/src/fiber_orm.nim
+++ b/src/fiber_orm.nim
@@ -282,11 +282,12 @@
 ##
 ## .. _pool.DbConnPool: fiber_orm/pool.html#DbConnPool
 ##
-import std/db_postgres, std/macros, std/options, std/sequtils, std/strutils
+import std/[db_common, logging, macros, options, sequtils, strutils]
 import namespaced_logging, uuids

 from std/unicode import capitalize

+import ./fiber_orm/db_common as fiber_db_common
 import ./fiber_orm/pool
 import ./fiber_orm/util

@@ -317,7 +318,7 @@ type
 var logNs {.threadvar.}: LoggingNamespace

 template log(): untyped =
-  if logNs.isNil: logNs = initLoggingNamespace(name = "fiber_orm", level = lvlNotice)
+  if logNs.isNil: logNs = getLoggerForNamespace(namespace = "fiber_orm", level = lvlNotice)
   logNs

 proc newMutateClauses(): MutateClauses =
@@ -326,7 +327,7 @@ proc newMutateClauses(): MutateClauses =
     placeholders: @[],
     values: @[])

-proc createRecord*[T](db: DbConn, rec: T): T =
+proc createRecord*[D: DbConnType, T](db: D, rec: T): T =
   ## Create a new record. `rec` is expected to be a `model class`_. The `id`
   ## field is only set if it is non-empty (see `ID Field`_ for details).
   ##
@@ -349,7 +350,7 @@ proc createRecord*[T](db: DbConn, rec: T): T =

   result = rowToModel(T, newRow)

-proc updateRecord*[T](db: DbConn, rec: T): bool =
+proc updateRecord*[D: DbConnType, T](db: D, rec: T): bool =
   ## Update a record by id. `rec` is expected to be a `model class`_.
   var mc = newMutateClauses()
   populateMutateClauses(rec, false, mc)
@@ -365,13 +366,13 @@ proc updateRecord*[T](db: DbConn, rec: T): bool =

   return numRowsUpdated > 0;

-template deleteRecord*(db: DbConn, modelType: type, id: typed): untyped =
+template deleteRecord*[D: DbConnType](db: D, modelType: type, id: typed): untyped =
   ## Delete a record by id.
   let sqlStmt = "DELETE FROM " & tableName(modelType) & " WHERE id = ?"
   log().debug "deleteRecord: [" & sqlStmt & "] id: " & $id
   db.tryExec(sql(sqlStmt), $id)

-proc deleteRecord*[T](db: DbConn, rec: T): bool =
+proc deleteRecord*[D: DbConnType, T](db: D, rec: T): bool =
   ## Delete a record by `id`_.
   ##
   ## .. _id: #model-class-id-field
@@ -379,7 +380,7 @@ proc deleteRecord*[T](db: DbConn, rec: T): bool =
   log().debug "deleteRecord: [" & sqlStmt & "] id: " & $rec.id
   return db.tryExec(sql(sqlStmt), $rec.id)

-template getRecord*(db: DbConn, modelType: type, id: typed): untyped =
+template getRecord*[D: DbConnType](db: D, modelType: type, id: typed): untyped =
   ## Fetch a record by id.
   let sqlStmt =
     "SELECT " & columnNamesForModel(modelType).join(",") &
@@ -394,8 +395,8 @@ template getRecord*(db: DbConn, modelType: type, id: typed): untyped =

   rowToModel(modelType, row)

-template findRecordsWhere*(
-    db: DbConn,
+template findRecordsWhere*[D: DbConnType](
+    db: D,
     modelType: type,
     whereClause: string,
     values: varargs[string, dbFormat],
@@ -432,8 +433,8 @@ template findRecordsWhere*(
       if page.isNone: records.len
       else: db.getRow(sql(countStmt), values)[0].parseInt)

-template getAllRecords*(
-    db: DbConn,
+template getAllRecords*[D: DbConnType](
+    db: D,
     modelType: type,
     page: Option[PaginationParams]): untyped =
   ## Fetch all records of the given type.
@@ -464,8 +465,8 @@ template getAllRecords*(
       else: db.getRow(sql(countStmt))[0].parseInt)

-template findRecordsBy*(
-    db: DbConn,
+template findRecordsBy*[D: DbConnType](
+    db: D,
     modelType: type,
     lookups: seq[tuple[field: string, value: string]],
     page: Option[PaginationParams]): untyped =
@@ -526,6 +527,10 @@ macro generateProcsForModels*(dbType: type, modelTypes: openarray[type]): untype
   result = newStmtList()

   for t in modelTypes:
+    if t.getType[1].typeKind == ntyRef:
+      raise newException(ValueError,
+        "fiber_orm model types must be objects, not refs")
+
     let modelName = $(t.getType[1])
     let getName = ident("get" & modelName)
     let getAllName = ident("getAll" & pluralize(modelName))
@@ -644,11 +649,12 @@ macro generateProcsForFieldLookups*(dbType: type, modelsAndFields: openarray[tup

   result.add procDefAST

-proc initPool*(
-  connect: proc(): DbConn,
+proc initPool*[D: DbConnType](
+  connect: proc(): D,
   poolSize = 10,
   hardCap = false,
-  healthCheckQuery = "SELECT 'true' AS alive"): DbConnPool =
+  healthCheckQuery = "SELECT 'true' AS alive"): DbConnPool[D] =
+
   ## Initialize a new DbConnPool. See the `initDb` procedure in the `Example
   ## Fiber ORM Usage`_ for an example
   ##
@@ -666,13 +672,13 @@ proc initPool*(
   ##
   ## .. _Example Fiber ORM Usage: #basic-usage-example-fiber-orm-usage

-  initDbConnPool(DbConnPoolConfig(
+  initDbConnPool(DbConnPoolConfig[D](
     connect: connect,
     poolSize: poolSize,
     hardCap: hardCap,
     healthCheckQuery: healthCheckQuery))

-template inTransaction*(db: DbConnPool, body: untyped) =
+template inTransaction*[D: DbConnType](db: DbConnPool[D], body: untyped) =
   pool.withConn(db):
     conn.exec(sql"BEGIN TRANSACTION")
     try:
diff --git a/src/fiber_orm/db_common.nim b/src/fiber_orm/db_common.nim
new file mode 100644
index 0000000..3736ea8
--- /dev/null
+++ b/src/fiber_orm/db_common.nim
@@ -0,0 +1,3 @@
+import std/[db_postgres, db_sqlite]
+
+type DbConnType* = db_postgres.DbConn or db_sqlite.DbConn
diff --git a/src/fiber_orm/pool.nim b/src/fiber_orm/pool.nim
index fb8ad2e..3177c3f 100644
--- a/src/fiber_orm/pool.nim
+++ b/src/fiber_orm/pool.nim
@@ -4,65 +4,70 @@
 ## Simple database connection pooling implementation compatible with Fiber ORM.

-import std/db_postgres, std/sequtils, std/strutils, std/sugar
+import std/[db_common, logging, sequtils, strutils, sugar]
+
+from std/db_sqlite import getRow
+from std/db_postgres import getRow

 import namespaced_logging
-
+import ./db_common as fiber_db_common

 type
-  DbConnPoolConfig* = object
-    connect*: () -> DbConn  ## Factory procedure to create a new DBConn
-    poolSize*: int          ## The pool capacity.
-    hardCap*: bool          ## Is the pool capacity a hard cap?
-                            ##
-                            ## When `false`, the pool can grow beyond the configured
-                            ## capacity, but will release connections down to the its
-                            ## capacity (no less than `poolSize`).
-                            ##
-                            ## When `true` the pool will not create more than its
-                            ## configured capacity. It a connection is requested, none
-                            ## are free, and the pool is at capacity, this will result
-                            ## in an Error being raised.
+  DbConnPoolConfig*[D: DbConnType] = object
+    connect*: () -> D     ## Factory procedure to create a new DBConn
+    poolSize*: int        ## The pool capacity.
+
+    hardCap*: bool        ## Is the pool capacity a hard cap?
+                          ##
+                          ## When `false`, the pool can grow beyond the
+                          ## configured capacity, but will release connections
+                          ## down to its capacity (no less than `poolSize`).
+                          ##
+                          ## When `true` the pool will not create more than its
+                          ## configured capacity. If a connection is requested,
+                          ## none are free, and the pool is at capacity, this
+                          ## will result in an Error being raised.
+
     healthCheckQuery*: string ## Should be a simple and fast SQL query that the
                               ## pool can use to test the liveliness of pooled
                               ## connections.

-  PooledDbConn = ref object
-    conn: DbConn
+  PooledDbConn[D: DbConnType] = ref object
+    conn: D
     id: int
     free: bool

-  DbConnPool* = ref object
+  DbConnPool*[D: DbConnType] = ref object
     ## Database connection pool
-    conns: seq[PooledDbConn]
-    cfg: DbConnPoolConfig
+    conns: seq[PooledDbConn[D]]
+    cfg: DbConnPoolConfig[D]
     lastId: int

 var logNs {.threadvar.}: LoggingNamespace

 template log(): untyped =
-  if logNs.isNil: logNs = initLoggingNamespace(name = "fiber_orm/pool", level = lvlNotice)
+  if logNs.isNil: logNs = getLoggerForNamespace(namespace = "fiber_orm/pool", level = lvlNotice)
   logNs

-proc initDbConnPool*(cfg: DbConnPoolConfig): DbConnPool =
+proc initDbConnPool*[D: DbConnType](cfg: DbConnPoolConfig[D]): DbConnPool[D] =
   log().debug("Initializing new pool (size: " & $cfg.poolSize)
-  result = DbConnPool(
+  result = DbConnPool[D](
     conns: @[],
     cfg: cfg)

-proc newConn(pool: DbConnPool): PooledDbConn =
+proc newConn[D: DbConnType](pool: DbConnPool[D]): PooledDbConn[D] =
   log().debug("Creating a new connection to add to the pool.")
   pool.lastId += 1
   let conn = pool.cfg.connect()
-  result = PooledDbConn(
+  result = PooledDbConn[D](
     conn: conn,
     id: pool.lastId,
     free: true)
   pool.conns.add(result)

-proc maintain(pool: DbConnPool): void =
+proc maintain[D: DbConnType](pool: DbConnPool[D]): void =
   log().debug("Maintaining pool. $# connections." % [$pool.conns.len])
-  pool.conns.keepIf(proc (pc: PooledDbConn): bool =
+  pool.conns.keepIf(proc (pc: PooledDbConn[D]): bool =
     if not pc.free: return true

     try:
@@ -91,7 +96,7 @@ proc maintain(pool: DbConnPool): void =
       "Trimming pool size. Culled $# free connections. $# connections remaining." %
       [$toCull.len, $pool.conns.len])

-proc take*(pool: DbConnPool): tuple[id: int, conn: DbConn] =
+proc take*[D: DbConnType](pool: DbConnPool[D]): tuple[id: int, conn: D] =
   ## Request a connection from the pool. Returns a DbConn if the pool has free
   ## connections, or if it has the capacity to create a new connection. If the
   ## pool is configured with a hard capacity limit and is out of free
@@ -113,13 +118,13 @@ proc take*(pool: DbConnPool): tuple[id: int, conn: DbConn] =
   log().debug("Reserve connection $#" % [$reserved.id])
   return (id: reserved.id, conn: reserved.conn)

-proc release*(pool: DbConnPool, connId: int): void =
+proc release*[D: DbConnType](pool: DbConnPool[D], connId: int): void =
   ## Release a connection back to the pool.
   log().debug("Reclaiming released connaction $#" % [$connId])
   let foundConn = pool.conns.filterIt(it.id == connId)
   if foundConn.len > 0: foundConn[0].free = true

-template withConn*(pool: DbConnPool, stmt: untyped): untyped =
+template withConn*[D: DbConnType](pool: DbConnPool[D], stmt: untyped): untyped =
   ## Convenience template to provide a connection from the pool for use in a
   ## statement block, automatically releasing that connnection when done.
   ##
diff --git a/src/fiber_orm/util.nim b/src/fiber_orm/util.nim
index f335aed..687562e 100644
--- a/src/fiber_orm/util.nim
+++ b/src/fiber_orm/util.nim
@@ -3,10 +3,10 @@
 # Copyright 2019 Jonathan Bernard
 ## Utility methods used internally by Fiber ORM.

-import json, macros, options, sequtils, strutils, times, unicode,
-  uuids
+import std/[json, macros, options, sequtils, strutils, times, unicode]
+import uuids

-import nre except toSeq
+import std/nre except toSeq

 type
   MutateClauses* = object
@@ -207,21 +207,14 @@ proc parseDbArray*(val: string): seq[string] =
   if not (parseState == inQuote) and curStr.len > 0: result.add(curStr)

-proc createParseStmt*(t, value: NimNode): NimNode =
+func createParseStmt*(t, value: NimNode): NimNode =
   ## Utility method to create the Nim cod required to parse a value coming from
   ## the a database query. This is used by functions like `rowToModel` to parse
   ## the dataabase columns into the Nim object fields.

-  #echo "Creating parse statment for ", t.treeRepr
   if t.typeKind == ntyObject:

-    if t.getType == UUID.getType:
-      result = quote do: parseUUID(`value`)
-
-    elif t.getType == DateTime.getType:
-      result = quote do: parsePGDatetime(`value`)
-
-    elif t.getTypeInst == Option.getType:
+    if t.getTypeInst == Option.getType:
       var innerType = t.getTypeImpl[2][0] # start at the first RecList

       # If the value is a non-pointer type, there is another inner RecList
       if innerType.kind == nnkRecList: innerType = innerType[0]
@@ -232,8 +225,28 @@ proc createParseStmt*(t, value: NimNode): NimNode =
         if `value`.len == 0: none[`innerType`]()
         else: some(`parseStmt`)

+    elif t.getType == UUID.getType:
+      result = quote do: parseUUID(`value`)
+
+    elif t.getType == DateTime.getType:
+      result = quote do: parsePGDatetime(`value`)
+
     else: error "Unknown value object type: " & $t.getTypeInst

+  elif t.typeKind == ntyGenericInst:
+
+    if t.kind == nnkBracketExpr and
+       t.len > 0 and
+       t[0] == Option.getType:
+
+      var innerType = t.getTypeInst[1]
+      let parseStmt = createParseStmt(innerType, value)
+      result = quote do:
+        if `value`.len == 0: none[`innerType`]()
+        else: some(`parseStmt`)
+
+    else: error "Unknown generic instance type: " & $t.getTypeInst
+
   elif t.typeKind == ntyRef:

     if $t.getTypeInst == "JsonNode":
@@ -268,28 +281,72 @@ proc createParseStmt*(t, value: NimNode): NimNode =

   else: error "Unknown value type: " & $t.typeKind

+func fields(t: NimNode): seq[tuple[fieldIdent: NimNode, fieldType: NimNode]] =
+  #[
+  debugEcho "T: " & t.treeRepr
+  debugEcho "T.kind: " & $t.kind
+  debugEcho "T.typeKind: " & $t.typeKind
+  debugEcho "T.GET_TYPE[1]: " & t.getType[1].treeRepr
+  debugEcho "T.GET_TYPE[1].kind: " & $t.getType[1].kind
+  debugEcho "T.GET_TYPE[1].typeKind: " & $t.getType[1].typeKind
+
+  debugEcho "T.GET_TYPE: " & t.getType.treeRepr
+  debugEcho "T.GET_TYPE[1].GET_TYPE: " & t.getType[1].getType.treeRepr
+  ]#
+
+  # Get the object type AST, with base object (if present) and record list.
+  var objDefAst: NimNode
+  if t.typeKind == ntyObject: objDefAst = t.getType
+  elif t.typeKind == ntyTypeDesc:
+    # In this case we have a type AST that is like:
+    # BracketExpr
+    #   Sym "typeDesc"
+    #   Sym "ModelType"
+    objDefAst = t.
+      getType[1].      # get the Sym "ModelType"
+      getType          # get the object definition type
+
+    if objDefAst.kind != nnkObjectTy:
+      error ("unable to enumerate the fields for model type '$#', " &
+        "tried to resolve the type of the provided symbol to an object " &
+        "definition (nnkObjectTy) but got a '$#'.\pAST:\p$#") % [
+        $t, $objDefAst.kind, objDefAst.treeRepr ]
+  else:
+    error ("unable to enumerate the fields for model type '$#', " &
+      "expected a symbol with type ntyTypeDesc but got a '$#'.\pAST:\p$#") % [
+      $t, $t.typeKind, t.treeRepr ]
+
+  # At this point objDefAst should look something like:
+  # ObjectTy
+  #   Empty
+  #   Sym "BaseObject" | Empty
+  #   RecList
+  #     Sym "field1"
+  #     Sym "field2"
+  #     ...
+
+  if objDefAst[1].kind == nnkSym:
+    # We have a base class symbol, let's recurse and try to resolve the fields
+    # for the base class
+    for fieldDef in objDefAst[1].fields: result.add(fieldDef)
+
+  for fieldDef in objDefAst[2].children:
+    # objDefAst[2] is a RecList of field definitions;
+    # ignore AST nodes that are not field definitions
+    if fieldDef.kind == nnkIdentDefs: result.add((fieldDef[0], fieldDef[1]))
+    elif fieldDef.kind == nnkSym: result.add((fieldDef, fieldDef.getTypeInst))
+    else: error "unknown object field definition AST: $#" % $fieldDef.kind
+
 template walkFieldDefs*(t: NimNode, body: untyped) =
   ## Iterate over every field of the given Nim object, yielding and defining
   ## `fieldIdent` and `fieldType`, the name of the field as a Nim Ident node
   ## and the type of the field as a Nim Type node respectively.
-  let tTypeImpl = t.getTypeImpl
+  for (fieldIdent {.inject.}, fieldType {.inject.}) in t.fields: body

-  var nodeToItr: NimNode
-  if tTypeImpl.typeKind == ntyObject: nodeToItr = tTypeImpl[2]
-  elif tTypeImpl.typeKind == ntyTypeDesc: nodeToItr = tTypeImpl.getType[1].getType[2]
-  else: error $t & " is not an object or type desc (it's a " & $tTypeImpl.typeKind & ")."
-
-  for fieldDef {.inject.} in nodeToItr.children:
-    # ignore AST nodes that are not field definitions
-    if fieldDef.kind == nnkIdentDefs:
-      let fieldIdent {.inject.} = fieldDef[0]
-      let fieldType {.inject.} = fieldDef[1]
-      body
-
-    elif fieldDef.kind == nnkSym:
-      let fieldIdent {.inject.} = fieldDef
-      let fieldType {.inject.} = fieldDef.getType
-      body

+#[ TODO: replace walkFieldDefs with things like this:
+func columnNamesForModel*(modelType: typedesc): seq[string] =
+  modelType.fields.mapIt(identNameToDb($it[0]))
+]#

 macro columnNamesForModel*(modelType: typed): seq[string] =
   ## Return the column names corresponding to the the fields of the given
@@ -317,6 +374,7 @@ macro rowToModel*(modelType: typed, row: seq[string]): untyped =
         createParseStmt(fieldType, itemLookup)))
     idx += 1

+#[
 macro listFields*(t: typed): untyped =
   var fields: seq[tuple[n: string, t: string]] = @[]
   t.walkFieldDefs:
@@ -324,6 +382,7 @@ macro listFields*(t: typed): untyped =
     else: fields.add((n: $fieldIdent, t: $fieldType))

   result = newLit(fields)
+]#

 proc typeOfColumn*(modelType: NimNode, colName: string): NimNode =
   ## Given a model type and a column name, return the Nim type for that column.
@@ -370,8 +429,8 @@ macro populateMutateClauses*(t: typed, newRecord: bool, mc: var MutateClauses):

     # if we're looking at an optional field, add logic to check for presence
     elif fieldType.kind == nnkBracketExpr and
-       fieldType.len > 0 and
-       fieldType[0] == Option.getType:
+         fieldType.len > 0 and
+         fieldType[0] == Option.getType:

       result.add quote do:
         `mc`.columns.add(identNameToDb(`fieldName`))
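
Usage sketch (illustrative only, not part of the patch): with `DbConnType` now covering both `db_postgres.DbConn` and `db_sqlite.DbConn`, the generic `initPool` and the CRUD procs above can be driven by the SQLite driver as well as Postgres. The `TodoItem` model and the `todo.db` filename below are hypothetical, and the sketch assumes `withConn` injects a `conn` variable the same way `inTransaction` uses it in this diff.

    import std/db_sqlite
    import fiber_orm, fiber_orm/pool

    type TodoItem* = object       # hypothetical model class (plain object, not a ref)
      id*: int
      summary*: string

    # Build a pool of SQLite connections; D is inferred as db_sqlite.DbConn.
    let db = initPool(
      connect = proc(): db_sqlite.DbConn = db_sqlite.open("todo.db", "", "", ""),
      poolSize = 4)

    db.withConn:
      # `conn` is the pooled db_sqlite.DbConn checked out for this block.
      discard conn.createRecord(TodoItem(summary: "try the SQLite backend"))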