diff --git a/src/fiber_orm.nim b/src/fiber_orm.nim index 02a6b71..a3f936a 100644 --- a/src/fiber_orm.nim +++ b/src/fiber_orm.nim @@ -348,6 +348,8 @@ proc createRecord*[D: DbConnType, T](db: D, rec: T): T = " RETURNING " & columnNamesForModel(rec).join(",") logQuery("createRecord", sqlStmt) + debug(logService.getLogger("fiber_orm/query"), %*{ "values": mc.values }) + let newRow = db.getRow(sql(sqlStmt), mc.values) result = rowToModel(T, newRow) @@ -500,6 +502,91 @@ template findRecordsBy*[D: DbConnType]( else: db.getRow(sql(countStmt), values)[0].parseInt) +template associate*[D: DbConnType, I, J]( + db: D, + joinTableName: string, + rec1: I, + rec2: J): void = + ## Associate two records via a join table. + + let insertStmt = + "INSERT INTO " & joinTableName & + " (" & tableName(I) & "_id, " & tableName(J) & "_id) " & + " VALUES (?, ?)" + + logQuery("associate", insertStmt, [("id1", $rec1.id), ("id2", $rec2.id)]) + db.exec(sql(insertStmt), [$rec1.id, $rec2.id]) + + +template findViaJoinTable*[D: DbConnType, L]( + db: D, + joinTableName: string, + targetType: type, + rec: L, + page: Option[PaginationParams]): untyped = + ## Find all records of `targetType` that are associated with `rec` via a + ## join table. + let columns = columnNamesForModel(targetType).mapIt("t." & it).join(",") + + var fetchStmt = + "SELECT " & columns & + " FROM " & tableName(targetType) & " AS t " & + " JOIN " & joinTableName & " AS j " & + " ON t.id = jt." & tableName(targetType) & "_id " & + " WHERE jt." & tableName(rec) & "_id = ?" + + var countStmt = + "SELECT COUNT(*) FROM " & joinTableName & + " WHERE " & tableName(rec) & "_id = ?" + + if page.isSome: fetchStmt &= getPagingClause(page.get) + + logQuery("findViaJoinTable", fetchStmt, [("id", $rec.id)]) + let records = db.getAllRows(sql(fetchStmt), $rec.id) + .mapIt(rowToModel(targetType, it)) + + PagedRecords[targetType]( + pagination: page, + records: records, + totalRecords: + if page.isNone: records.len + else: db.getRow(sql(countStmt))[0].parseInt) + +template findViaJoinTable*[D: DbConnType]( + db: D, + joinTableName: string, + targetType: type, + lookupType: type, + id: typed, + page: Option[PaginationParams]): untyped = + ## Find all records of `targetType` that are associated with a record of + ## `lookupType` via a join table. + let columns = columnNamesForModel(targetType).mapIt("t." & it).join(",") + + var fetchStmt = + "SELECT " & columns & + " FROM " & tableName(targetType) & " AS t " & + " JOIN " & joinTableName & " AS j " & + " ON t.id = jt." & tableName(targetType) & "_id " & + " WHERE jt." & tableName(lookupType) & "_id = ?" + + var countStmt = + "SELECT COUNT(*) FROM " & joinTableName & + " WHERE " & tableName(lookupType) & "_id = ?" + + if page.isSome: fetchStmt &= getPagingClause(page.get) + + logQuery("findViaJoinTable", fetchStmt, [("id", $id)]) + let records = db.getAllRows(sql(fetchStmt), $id) + .mapIt(rowToModel(targetType, it)) + + PagedRecords[targetType]( + pagination: page, + records: records, + totalRecords: + if page.isNone: records.len + else: db.getRow(sql(countStmt))[0].parseInt) + macro generateProcsForModels*(dbType: type, modelTypes: openarray[type]): untyped = ## Generate all standard access procedures for the given model types. For a ## `model class`_ named `TodoItem`, this will generate the following @@ -656,6 +743,103 @@ macro generateProcsForFieldLookups*(dbType: type, modelsAndFields: openarray[tup result.add procDefAST +macro generateJoinTableProcs*( + dbType, model1Type, model2Type: type, + joinTableName: string): untyped = + ## Generate lookup procedures for a pair of models with a join table. For + ## example, given the TODO database demonstrated above, where `TodoItem` and + ## `TimeEntry` have a many-to-many relationship, you might have a join table + ## `todo_items_time_entries` with columns `todo_item_id` and `time_entry_id`. + ## This macro will generate the following procedures: + ## + ## .. code-block:: Nim + ## proc findTodoItemsByTimeEntry*(db: SampleDB, timeEntry: TimeEntry): seq[TodoItem] + ## proc findTimeEntriesByTodoItem*(db: SampleDB, todoItem: TodoItem): seq[TimeEntry] + ## + ## `dbType` is expected to be some type that has a defined `withConnection` + ## procedure (see `Database Object`_ for details). + ## + ## .. _Database Object: #database-object + result = newStmtList() + + if model1Type.getType[1].typeKind == ntyRef or + model2Type.getType[1].typeKind == ntyRef: + raise newException(ValueError, + "fiber_orm model object must be objects, not refs") + + let model1Name = $(model1Type.getType[1]) + let model2Name = $(model2Type.getType[1]) + let getModel1Name = ident("get" & pluralize(model1Name) & "By" & model2Name) + let getModel2Name = ident("get" & pluralize(model2Name) & "By" & model1Name) + let id1Type = typeOfColumn(model1Type, "id") + let id2Type = typeOfColumn(model2Type, "id") + let joinTableNameNode = newStrLitNode($joinTableName) + + result.add quote do: + proc `getModel1Name`*( + db: `dbType`, + id: `id2Type`, + pagination = none[PaginationParams]()): PagedRecords[`model1Type`] = + db.withConnection conn: + result = findViaJoinTable( + conn, + `joinTableNameNode`, + `model1Type`, + `model2Type`, + id, + pagination) + + proc `getModel1Name`*( + db: `dbType`, + rec: `model2Type`, + pagination = none[PaginationParams]()): PagedRecords[`model1Type`] = + db.withConnection conn: + result = findViaJoinTable( + conn, + `joinTableNameNode`, + `model1Type`, + rec, + pagination) + + proc `getModel2Name`*( + db: `dbType`, + id: `id1Type`, + pagination = none[PaginationParams]()): Pagedrecords[`model2Type`] = + db.withConnection conn: + result = findViaJoinTable( + conn, + `joinTableNameNode`, + `model2Type`, + `model1Type`, + id, + pagination) + + proc `getModel2Name`*( + db: `dbType`, + rec: `model1Type`, + pagination = none[PaginationParams]()): Pagedrecords[`model2Type`] = + db.withConnection conn: + result = findViaJoinTable( + conn, + `joinTableNameNode`, + `model2Type`, + rec, + pagination) + + proc associate*( + db: `dbType`, + rec1: `model1Type`, + rec2: `model2Type`): void = + db.withConnection conn: + associate(conn, `joinTableNameNode`, rec1, rec2) + + proc associate*( + db: `dbType`, + rec2: `model2Type`, + rec1: `model1Type`): void = + db.withConnection conn: + associate(conn, `joinTableNameNode`, rec1, rec2) + proc initPool*[D: DbConnType]( connect: proc(): D, poolSize = 10, diff --git a/src/fiber_orm/util.nim b/src/fiber_orm/util.nim index d4e04cf..e397252 100644 --- a/src/fiber_orm/util.nim +++ b/src/fiber_orm/util.nim @@ -217,7 +217,7 @@ proc parseDbArray*(val: string): seq[string] = result.add(curStr) func createParseStmt*(t, value: NimNode): NimNode = - ## Utility method to create the Nim cod required to parse a value coming from + ## Utility method to create the Nim code required to parse a value coming from ## the a database query. This is used by functions like `rowToModel` to parse ## the dataabase columns into the Nim object fields. @@ -240,7 +240,7 @@ func createParseStmt*(t, value: NimNode): NimNode = elif t.getType == DateTime.getType: result = quote do: parsePGDatetime(`value`) - else: error "Unknown value object type: " & $t.getTypeInst + else: error "Cannot parse column with unknown object type: " & $t.getTypeInst elif t.typeKind == ntyGenericInst: @@ -254,7 +254,7 @@ func createParseStmt*(t, value: NimNode): NimNode = if `value`.len == 0: none[`innerType`]() else: some(`parseStmt`) - else: error "Unknown generic instance type: " & $t.getTypeInst + else: error "Cannot parse column with unknown generic instance type: " & $t.getTypeInst elif t.typeKind == ntyRef: @@ -262,7 +262,7 @@ func createParseStmt*(t, value: NimNode): NimNode = result = quote do: parseJson(`value`) else: - error "Unknown ref type: " & $t.getTypeInst + error "Cannot parse column with unknown ref type: " & $t.getTypeInst elif t.typeKind == ntySequence: let innerType = t[1] @@ -281,14 +281,14 @@ func createParseStmt*(t, value: NimNode): NimNode = result = quote do: parseFloat(`value`) elif t.typeKind == ntyBool: - result = quote do: "true".startsWith(`value`.toLower) + result = quote do: "true".startsWith(`value`.toLower) or `value` == "1" elif t.typeKind == ntyEnum: let innerType = t.getTypeInst result = quote do: parseEnum[`innerType`](`value`) else: - error "Unknown value type: " & $t.typeKind + error "Cannot parse column with unknown value type: " & $t.typeKind func fields(t: NimNode): seq[tuple[fieldIdent: NimNode, fieldType: NimNode]] = #[