2023-11-09 18:47:11 +01:00

35 lines
1.2 KiB
Python

import re
from itertools import accumulate

from django.db.models import Lookup
from django.db.models.fields import Field
from django.db.models.lookups import In
@Field.register_lookup
class AncestorLevelsOf(In):
    """Find ancestors based on a level/path string in the format of e.g. ``1_1_123_4``.

    ``Model.objects.filter(path__ancestorsof="1_1_123_4")`` becomes an ``IN``
    lookup over every prefix of the path: "1", "1_1", "1_1_123" and
    "1_1_123_4". Note that the given path itself is part of the result.
    """

    lookup_name = "ancestorsof"

    def get_prep_lookup(self):
        """Expand the rhs path into the list of all its prefix (ancestor) paths.

        This gets called before ``as_sql()`` and its return value is assigned
        as ``self.rhs``. The expanded list is handed on to the inherited
        ``In.get_prep_lookup`` so the values receive the same preparation as a
        plain ``__in`` lookup would give them.
        """
        levels = self.rhs.split("_")
        # Running joins of the levels: "1", "1_1", "1_1_123", "1_1_123_4".
        self.rhs = list(
            accumulate(levels, lambda prefix, level: f"{prefix}_{level}")
        )
        # Delegate so the parent lookup prepares each value (instead of
        # bypassing it as the previous override did).
        return super().get_prep_lookup()


@Field.register_lookup
class SiblingLevelsOf(Lookup):
    """Find siblings based on a level/path string in the format of e.g. ``1_1_123_4``.

    ``Model.objects.filter(path__siblingsof="1_1_123_4")`` matches every path
    made of the same parent ("1_1_123") followed by exactly one more level,
    e.g. "1_1_123_4" itself and "1_1_123_7". For a top-level path such as "1",
    every single-level path matches.
    """

    lookup_name = "siblingsof"

    def get_prep_lookup(self):
        """Replace the rhs path with a regex matching every child of its parent.

        This gets called before ``as_sql()`` and its return value is assigned
        as ``self.rhs``.
        """
        parent, separator, _ = self.rhs.rpartition("_")
        if not separator:
            # Top-level path: there is no parent prefix, so the siblings are
            # all other single-level paths. (Previously this produced
            # "^_[^_]+$", which matched nothing.)
            return "^[^_]+$"
        # Escape the parent so regex metacharacters inside a level are matched
        # literally. "_" is not a metacharacter, so it needs no backslash.
        return "^" + re.escape(parent) + "_[^_]+$"

    def as_postgresql(self, compiler, connection):
        """Compile to a PostgreSQL regex match of the column against the pattern."""
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        # Unpack rather than "+" so mixed list/tuple params cannot raise TypeError.
        params = (*lhs_params, *rhs_params)
        # "~*" is PostgreSQL's case-insensitive POSIX regex match operator.
        return f"{lhs} ~* {rhs}", params