|
| 1 | +import nimpy |
| 2 | +import std/[os, sugar, enumerate, strutils, tables, sequtils, options, times, macros, paths, math] |
| 3 | +import ../[pymodules, utils, pytypes, numpy, dateutils, nimpyext] |
| 4 | + |
const
  FILTER_KEYS = ["column1", "column2", "criteria", "value1", "value2"]
    ## Dictionary keys accepted in a filter expression.
  FILTER_OPS = [">", ">=", "==", "<", "<=", "!=", "in"]
    ## Criteria strings accepted under the "criteria" key.
  FILTER_TYPES = ["all", "any"]
    ## Accepted filter type names (matched after lower-casing).
| 8 | + |
type FilterMethods = enum
  ## Comparison applied between the left and right side of an expression.
  FM_GT ## criteria ">"
  FM_GE ## criteria ">="
  FM_EQ ## criteria "=="
  FM_LT ## criteria "<"
  FM_LE ## criteria "<="
  FM_NE ## criteria "!="
  FM_IN ## criteria "in"
| 17 | + |
type FilterType = enum
  ## How the results of several expressions are combined for one row.
  FT_ALL ## every expression must hold
  FT_ANY ## at least one expression must hold
| 21 | + |
type
  ExpressionValue = (Option[string], Option[PY_ObjectND])
    ## One operand of an expression: (column name, literal value).
    ## `filter` always builds it with exactly one of the two set.
  Expression = (ExpressionValue, FilterMethods, ExpressionValue)
    ## (left operand, comparison, right operand).
| 24 | + |
iterator iterateRows(exprColumns: seq[string], tablePages: Table[string, seq[string]]): seq[PY_ObjectND] =
  ## Walks the columns named in `exprColumns` in lockstep and yields one row
  ## per step. Each row holds one value per column, in `exprColumns` order.
  ##
  ## `tablePages` maps a column name to the page list handed to
  ## `iterateColumn`. Iteration stops as soon as any column iterator is
  ## finished.
  ##
  ## With an empty `exprColumns` nothing is yielded: there is no column to
  ## determine the number of rows, and looping would never terminate.
  var allIters = newSeq[iterator(): PY_ObjectND]()
  var res: seq[PY_ObjectND] = @[]
  # Named `exhausted` so it does not shadow `system.finished`, which is
  # called on the closure iterators below.
  var exhausted = exprColumns.len == 0

  proc makeIterable(column: seq[string]): auto =
    # Wrap the column iteration in a first-class iterator so that several
    # columns can be advanced one element at a time.
    return iterator(): auto =
      for v in iterateColumn[PY_ObjectND](column):
        yield v

  # Prime every iterator and build the first row.
  for column in exprColumns:
    let i = makeIterable(tablePages[column])

    allIters.add(i)

    res.add(i())
    exhausted = exhausted or finished(i)

  while not exhausted:
    yield res

    res = newSeqOfCap[PY_ObjectND](allIters.len)

    # A closure iterator only reports `finished` after the call that ran
    # past its last element, so the row built in that pass is discarded.
    for i in allIters:
      res.add(i())
      exhausted = exhausted or finished(i)
| 51 | + |
proc extractValue(row: seq[PY_ObjectND], exprCols: seq[string], value: ExpressionValue): PY_ObjectND {.inline.} =
  ## Resolves one operand of an expression for the given row.
  ##
  ## A literal (second tuple field) takes precedence; otherwise the column
  ## name (first tuple field) is looked up in `exprCols` and the value at
  ## that position of `row` is returned.
  let literal = value[1]

  if literal.isSome:
    result = literal.get
  else:
    let position = exprCols.find(value[0].get)
    result = row[position]
| 61 | + |
proc checkExpression(row: seq[PY_ObjectND], exprCols: seq[string], xpr: Expression): bool {.inline.} =
  ## Evaluates a single expression against `row`.
  ##
  ## Both operands are resolved through `extractValue` (left first), then
  ## compared with the operator selected by the expression's criteria.
  let lhs = row.extractValue(exprCols, xpr[0])
  let rhs = row.extractValue(exprCols, xpr[2])

  result =
    case xpr[1]
    of FM_GT: lhs > rhs
    of FM_GE: lhs >= rhs
    of FM_EQ: lhs == rhs
    of FM_LT: lhs < rhs
    of FM_LE: lhs <= rhs
    of FM_NE: lhs != rhs
    of FM_IN: lhs.contains(rhs) # note: the LEFT operand is the container
| 78 | + |
proc checkExpressions(row: seq[PY_ObjectND], exprCols: seq[string], expressions: seq[Expression], filterType: FilterType): bool {.inline.} =
  ## Combines the per-expression results for `row`:
  ## `FT_ALL` requires every expression to hold, `FT_ANY` at least one.
  case filterType
  of FT_ALL: expressions.allIt(row.checkExpression(exprCols, it))
  of FT_ANY: expressions.anyIt(row.checkExpression(exprCols, it))
| 83 | + |
proc filter*(table: nimpy.PyObject, pyExpressions: seq[nimpy.PyObject], filterTypeName: string, tqdm: nimpy.PyObject): (nimpy.PyObject, nimpy.PyObject) =
  ## Splits `table` into rows that satisfy the expressions and rows that
  ## do not.
  ##
  ## Parameters:
  ## * `table` - the Python table object to filter.
  ## * `pyExpressions` - Python dicts with exactly three entries taken from
  ##   `FILTER_KEYS`: "criteria" (one of `FILTER_OPS`), one of
  ##   "column1"/"value1" for the left side and one of "column2"/"value2"
  ##   for the right side.
  ## * `filterTypeName` - "all" or "any" (case-insensitive).
  ## * `tqdm` - progress bar class to instantiate; when `isNone(tqdm)` the
  ##   class from the loaded `tqdm` module is used.
  ##
  ## Returns `(passTable, failTable)`, both of the same type as `table`.
  ## With no expressions or no columns, `(table, <empty table>)` is
  ## returned without reading any data.
  ##
  ## Raises `ValueError` for an unknown filter type, unknown keys, a dict
  ## that is not of length 3, an unknown criteria, an unknown column, a
  ## missing or duplicated operand, or columns of unequal length, and
  ## `KeyError` when an expression is not a dict.
  let m = modules()
  let builtins = m.builtins
  let tablite = m.tablite
  let base = tablite.modules.base
  let Config = tablite.modules.config.classes.Config
  let TableClass = builtins.getType(table)

  let filterType = (
    case filterTypeName.toLower():
    of "any": FT_ANY
    of "all": FT_ALL
    else: raise newException(ValueError, "invalid filter type '" & filterTypeName & "' expected " & $FILTER_TYPES)
  )

  if pyExpressions.len == 0:
    return (table, TableClass!())

  let columns = collect: (for c in table.columns.keys(): c.to(string))

  if columns.len == 0:
    return (table, TableClass!())

  # Names of the columns referenced by at least one expression, in first-use
  # order. A row handed to `checkExpressions` is laid out in this order.
  var exprCols = newSeq[string]()

  var tablePages = initTable[string, seq[string]]()
  var passTablePages = initOrderedTable[string, nimpy.PyObject]()
  var failTablePages = initOrderedTable[string, nimpy.PyObject]()

  var colLen = 0

  for (i, c) in enumerate(columns):
    let pyCol = table[c]
    let colPages = base.collectPages(pyCol)
    tablePages[c] = colPages
    passTablePages[c] = base.classes.ColumnClass!(pyCol.path)
    failTablePages[c] = base.classes.ColumnClass!(pyCol.path)

    # Every column has to be as long as the first one.
    if i == 0:
      colLen = getColumnLen(colPages)
    elif colLen != getColumnLen(colPages):
      raise newException(ValueError, "table columns must have equal number of rows")

  # Builds one operand from the `expression` dict of the enclosing loop:
  # either a column reference (`columnName` key) or a literal (`valueName`
  # key). `expression` is resolved at the instantiation site.
  template addParam(columnName: string, valueName: string, paramName: string): auto =
    var res {.noInit.}: ExpressionValue

    if columnName.contains(expression):
      if valueName.contains(expression):
        raise newException(ValueError, "filter can only take 1 " & paramName & " expr element, got 2")

      let c = expression[columnName].to(string)

      if c notin columns:
        raise newException(ValueError, "no such column '" & $c & "' in " & $columns)

      if c notin exprCols:
        exprCols.add(c)

      res = (some(c), none[PY_ObjectND]())
    elif not valueName.contains(expression):
      raise newException(ValueError, "no " & paramName & " parameter")
    else:
      let pyVal = expression[valueName]
      let pyType = builtins.getTypeName(pyVal)
      let obj: PY_ObjectND = (
        case pyType
        of "int": newPY_Object(pyVal.to(int))
        of "float": newPY_Object(pyVal.to(float))
        of "bool": newPY_Object(pyVal.to(bool))
        of "str": newPY_Object(pyVal.to(string))
        of "datetime": newPY_Object(pyDateTime2NimDateTime(pyVal), K_DATETIME)
        of "date": newPY_Object(pyDate2NimDateTime(pyVal), K_DATE)
        of "time": newPY_Object(pyTime2NimDuration(pyVal))
        else: implement("invalid object type: " & pyType)
      )
      res = (none[string](), some(obj))

    res

  var expressions = newSeq[Expression]()

  for expression in pyExpressions:
    if not builtins.isinstance(expression, builtins.classes.DictClass):
      raise newException(KeyError, "expression must be a dict: " & $expression)

    # `not x == 3` would parse as `(not x) == 3` and never trigger, so the
    # comparison is spelled with `!=`.
    if builtins.getLen(expression) != 3:
      raise newException(ValueError, "expression must be of len 3: " & $builtins.getLen(expression))

    let invalidKeys = collect:
      for pyKey in expression.keys():
        let key = pyKey.to(string)

        if key in FILTER_KEYS:
          continue

        key

    if invalidKeys.len > 0:
      raise newException(ValueError, "got unknown keys " & $invalidKeys & " expected " & $FILTER_KEYS)

    let criteria = expression["criteria"].to(string)

    let crit: FilterMethods = (
      case criteria
      of ">": FM_GT
      of ">=": FM_GE
      of "==": FM_EQ
      of "<": FM_LT
      of "<=": FM_LE
      of "!=": FM_NE
      of "in": FM_IN
      else: raise newException(ValueError, "invalid criteria '" & criteria & "' expected " & $FILTER_OPS)
    )

    let lOpt = addParam("column1", "value1", "left")
    let rOpt = addParam("column2", "value2", "right")

    expressions.add((lOpt, crit, rOpt))

  # When every operand is a literal no column was registered, and iterating
  # over zero columns cannot produce one row per table row. Let the first
  # table column drive the iteration; literal operands never index the row.
  if exprCols.len == 0:
    exprCols.add(columns[0])

  let pageSize = Config.PAGE_SIZE.to(int)
  let workdir = builtins.toStr(Config.workdir)
  let pidir = Config.pid.to(string)
  let basedir = Path(workdir) / Path(pidir)
  let pagedir = basedir / Path("pages")

  createDir(string pagedir)

  # `bitmask[0 ..< bitNum]` holds the verdicts for table rows
  # `offset ..< offset + bitNum` that have not been written out yet.
  var bitmask = newSeq[bool](pageSize)
  var bitNum = 0
  var offset = 0

  # Writes the buffered verdicts for one column: for every source page that
  # overlaps the buffered row range, the passing and failing elements are
  # saved as new pages and appended to `passColumn` / `failColumn`.
  template dumpPage(columns: seq[string], passColumn: nimpy.PyObject, failColumn: nimpy.PyObject): void =
    var firstPage = 0
    var currentOffset = 0

    # Locate the page that contains table row `offset`.
    while true:
      let skippedLen = getPageLen(columns[firstPage])

      if offset < currentOffset + skippedLen:
        break

      inc firstPage
      currentOffset = currentOffset + skippedLen

    var maskOffset = 0

    # Index inside the current page of the first buffered row. Only the
    # first page can start mid-way; later pages start at 0.
    var pageStart = offset - currentOffset

    while maskOffset < bitNum:
      let page = readNumpy(columns[firstPage])

      let pageLen = page.len
      # Take what is left of the buffer, capped by what is left of the page.
      let sliceLen = min(bitNum - maskOffset, pageLen - pageStart)
      let slice = maskOffset ..< (maskOffset + sliceLen)

      var validIndices = newSeqOfCap[int](sliceLen - (sliceLen shr 2))
      var invalidIndices = newSeqOfCap[int](sliceLen shr 2)

      for (i, passed) in enumerate(bitmask[slice]):
        if passed: validIndices.add(i + pageStart)
        else: invalidIndices.add(i + pageStart)

      let passPid = base.classes.SimplePageClass.next_id(string basedir).to(string)
      let failPid = base.classes.SimplePageClass.next_id(string basedir).to(string)

      let passPath = string(pagedir / Path(passPid & ".npy"))
      let failPath = string(pagedir / Path(failPid & ".npy"))

      let passPage = page[validIndices]
      let failPage = page[invalidIndices]

      passPage.save(passPath)
      failPage.save(failPath)

      let passPagePy = newPyPage(passPage, string basedir, passPid)
      let failPagePy = newPyPage(failPage, string basedir, failPid)

      discard passColumn.pages.append(passPagePy)
      discard failColumn.pages.append(failPagePy)

      maskOffset = maskOffset + sliceLen
      pageStart = 0
      inc firstPage

  # Flushes the buffered verdicts for every column of the table.
  template dumpPages(tablePages: Table[string, seq[string]]): void =
    for (key, col) in tablePages.pairs():
      col.dumpPage(passTablePages[key], failTablePages[key])

  let tableLen = builtins.getLen(table)
  let tqdmLen = int ceil(float(tableLen) / float(pageSize))
  let TqdmClass = (if isNone(tqdm): m.tqdm.classes.TqdmClass else: tqdm)
  let pbar = TqdmClass!(total: tqdmLen, desc="filter")

  for row in exprCols.iterateRows(tablePages):
    bitmask[bitNum] = row.checkExpressions(exprCols, expressions, filterType)

    inc bitNum

    # Buffer full: write it out and advance to the next chunk of rows.
    if bitNum >= pageSize:
      tablePages.dumpPages()
      offset = offset + bitNum
      bitNum = 0
      discard pbar.update(1)

  # Write the trailing, partially filled buffer.
  if bitNum > 0:
    tablePages.dumpPages()
    discard pbar.update(1)

  # Instantiates `T` and assigns every collected column to it.
  template makeTable(T: nimpy.PyObject, columns: OrderedTable[string, nimpy.PyObject]): nimpy.PyObject =
    let tbl = T!()

    for (k, v) in columns.pairs:
      tbl[k] = v

    tbl

  let passTable = makeTable(TableClass, passTablePages)
  let failTable = makeTable(TableClass, failTablePages)

  discard pbar.close()

  return (passTable, failTable)
| 306 | + |
| 307 | + |
when appType != "lib":
  # Manual smoke test; compiled only when the module is not built as a
  # library. Filters a small three-column table and shows both results.
  let mods = modules()
  let cfg = mods.tablite.modules.config.classes.Config

  let table = mods.tablite.classes.TableClass!({
    "a": @[1, 2, 3, 4],
    "b": @[10, 20, 30, 40],
    "c": @[4, 4, 4, 4]
  }.toTable)

  let pyExpressions = @[
    mods.builtins.classes.DictClass!(column1: "a", criteria: ">=", value2: 2),
    # Alternative expression, currently disabled:
    # mods.builtins.classes.DictClass!(column1: "b", criteria: "==", value2: 20),
    mods.builtins.classes.DictClass!(column1: "a", criteria: "==", column2: "c"),
  ]

  # The page size is lowered only after the table has been created.
  cfg.PAGE_SIZE = 2

  let (tblPass, tblFail) = filter(table, pyExpressions, "all", nil)

  discard tblPass.show()
  discard tblFail.show()
0 commit comments