Handle TAG queries that include the separator

This commit is contained in:
Andrew Brookins 2021-10-05 16:40:02 -07:00
parent b46408ccd2
commit 8f32b359f0
7 changed files with 591 additions and 128 deletions

View file

@ -1,9 +1,9 @@
import abc
import dataclasses
import decimal
import json
import logging
import operator
import re
from copy import deepcopy
from enum import Enum
from functools import reduce
@ -22,10 +22,9 @@ from typing import (
no_type_check,
Protocol,
List,
Type,
Pattern, get_origin, get_args
get_origin,
get_args, Type
)
import uuid
import redis
from pydantic import BaseModel, validator
@ -34,46 +33,43 @@ from pydantic.fields import ModelField, Undefined, UndefinedType
from pydantic.main import ModelMetaclass
from pydantic.typing import NoArgAnyCallable, resolve_annotations
from pydantic.utils import Representation
from ulid import ULID
from .encoders import jsonable_encoder
from .render_tree import render_tree
from .token_escaper import TokenEscaper
model_registry = {}
_T = TypeVar("_T")
log = logging.getLogger(__name__)
class TokenEscaper:
"""
Escape punctuation within an input string.
"""
# Characters that RediSearch requires us to escape during queries.
# Source: https://oss.redis.com/redisearch/Escaping/#the_rules_of_text_field_tokenization
DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\ ]"
def __init__(self, escape_chars_re: Optional[Pattern] = None):
if escape_chars_re:
self.escaped_chars_re = escape_chars_re
else:
self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)
def escape(self, string):
def escape_symbol(match):
value = match.group(0)
return f"\\{value}"
return self.escaped_chars_re.sub(escape_symbol, string)
escaper = TokenEscaper()
# For basic exact-match field types like an indexed string, we create a TAG
# field in the RediSearch index. TAG is designed for multi-value fields
# separated by a "separator" character. We're using the field for single values
# (multi-value TAGs will be exposed as a separate field type), and we use the
# pipe character (|) as the separator. There is no way to escape this character
# in hash fields or JSON objects, so if someone indexes a value that includes
# the pipe, we'll warn but allow, and then warn again if they try to query for
# values that contain this separator.
SINGLE_VALUE_TAG_FIELD_SEPARATOR = "|"
# This is the default field separator in RediSearch. We need it to determine if
# someone has accidentally passed in the field separator with string value of a
# multi-value field lookup, like a IN or NOT_IN.
DEFAULT_REDISEARCH_FIELD_SEPARATOR = ","
class RedisModelError(Exception):
pass
"""Raised when a problem exists in the definition of a RedisModel."""
class QuerySyntaxError(Exception):
"""Raised when a query is constructed improperly."""
class NotFoundError(Exception):
"""A query found no results."""
"""Raised when a query found no results."""
class Operators(Enum):
@ -91,9 +87,45 @@ class Operators(Enum):
LIKE = 12
ALL = 13
def __str__(self):
return str(self.name)
ExpressionOrModelField = Union['Expression', 'NegatedExpression', ModelField]
class ExpressionProtocol(Protocol):
op: Operators
left: ExpressionOrModelField
right: ExpressionOrModelField
def __invert__(self) -> 'Expression':
pass
def __and__(self, other: ExpressionOrModelField):
pass
def __or__(self, other: ExpressionOrModelField):
pass
@property
def name(self) -> str:
raise NotImplementedError
@property
def tree(self) -> str:
raise NotImplementedError
@dataclasses.dataclass
class NegatedExpression:
"""A negated Expression object.
For now, this is a separate dataclass from Expression that acts as a facade
around an Expression, indicating to model code (specifically, code
responsible for querying) to negate the logic in the wrapped Expression. A
better design is probably possible, maybe at least an ExpressionProtocol?
"""
expression: 'Expression'
def __invert__(self):
@ -105,22 +137,53 @@ class NegatedExpression:
def __or__(self, other):
return Expression(left=self, op=Operators.OR, right=other)
@property
def left(self):
return self.expression.left
@property
def right(self):
return self.expression.right
@property
def op(self):
return self.expression.op
@property
def name(self):
if self.expression.op is Operators.EQ:
return f"NOT {self.expression.name}"
else:
return f"{self.expression.name} NOT"
@property
def tree(self):
return render_tree(self)
@dataclasses.dataclass
class Expression:
op: Operators
left: Any
right: Any
left: ExpressionOrModelField
right: ExpressionOrModelField
def __invert__(self):
return NegatedExpression(self)
def __and__(self, other):
def __and__(self, other: ExpressionOrModelField):
return Expression(left=self, op=Operators.AND, right=other)
def __or__(self, other):
def __or__(self, other: ExpressionOrModelField):
return Expression(left=self, op=Operators.OR, right=other)
@property
def name(self):
return str(self.op)
@property
def tree(self):
return render_tree(self)
ExpressionOrNegated = Union[Expression, NegatedExpression]
@ -129,22 +192,22 @@ class ExpressionProxy:
def __init__(self, field: ModelField):
self.field = field
def __eq__(self, other: Any) -> Expression:
def __eq__(self, other: Any) -> Expression: # type: ignore[override]
return Expression(left=self.field, op=Operators.EQ, right=other)
def __ne__(self, other: Any) -> Expression:
def __ne__(self, other: Any) -> Expression: # type: ignore[override]
return Expression(left=self.field, op=Operators.NE, right=other)
def __lt__(self, other: Any) -> Expression:
def __lt__(self, other: Any) -> Expression: # type: ignore[override]
return Expression(left=self.field, op=Operators.LT, right=other)
def __le__(self, other: Any) -> Expression:
def __le__(self, other: Any) -> Expression: # type: ignore[override]
return Expression(left=self.field, op=Operators.LE, right=other)
def __gt__(self, other: Any) -> Expression:
def __gt__(self, other: Any) -> Expression: # type: ignore[override]
return Expression(left=self.field, op=Operators.GT, right=other)
def __ge__(self, other: Any) -> Expression:
def __ge__(self, other: Any) -> Expression: # type: ignore[override]
return Expression(left=self.field, op=Operators.GE, right=other)
@ -184,9 +247,9 @@ class FindQuery:
self.sort_fields = []
self._expression = None
self._query = None
self._pagination = []
self._model_cache = []
self._query: Optional[str] = None
self._pagination: list[str] = []
self._model_cache: list[RedisModel] = []
@property
def pagination(self):
@ -236,24 +299,24 @@ class FindQuery:
else:
# TAG fields are the default field type.
# TODO: A ListField or ArrayField that supports multiple values
# and contains logic.
# and contains logic should allow IN and NOT_IN queries.
return RediSearchFieldTypes.TAG
@staticmethod
def expand_tag_value(value):
err = RedisModelError(f"Using the IN operator requires passing a sequence of "
"possible values. You passed: {value}")
if isinstance(str, value):
raise err
return value
try:
expanded_value = "|".join([escaper.escape(v) for v in value])
except TypeError:
raise err
raise QuerySyntaxError("Values passed to an IN query must be iterables,"
"like a list of strings. For more information, see:"
"TODO: doc.")
return expanded_value
@classmethod
def resolve_value(cls, field_name: str, field_type: RediSearchFieldTypes,
op: Operators, value: Any) -> str:
field_info: PydanticFieldInfo, op: Operators, value: Any) -> str:
result = ""
if field_type is RediSearchFieldTypes.TEXT:
result = f"@{field_name}:"
@ -282,17 +345,41 @@ class FindQuery:
result += f"@{field_name}:[{value} +inf]"
elif op is Operators.LE:
result += f"@{field_name}:[-inf {value}]"
# TODO: How will we know the difference between a multi-value use of a TAG
# field and our hidden use of TAG for exact-match queries?
elif field_type is RediSearchFieldTypes.TAG:
if op is Operators.EQ:
value = escaper.escape(value)
result += f"@{field_name}:{{{value}}}"
separator_char = getattr(field_info, 'separator',
SINGLE_VALUE_TAG_FIELD_SEPARATOR)
if value == separator_char:
# The value is ONLY the TAG field separator character --
# this is not going to work.
log.warning("Your query against the field %s is for a single character, %s, "
"that is used internally by redis-developer-python. We must ignore "
"this portion of the query. Please review your query to find "
"an alternative query that uses a string containing more than "
"just the character %s.", field_name, separator_char, separator_char)
return ""
if separator_char in value:
# The value contains the TAG field separator. We can work
# around this by breaking apart the values and unioning them
# with multiple field:{} queries.
values = filter(None, value.split(separator_char))
for value in values:
value = escaper.escape(value)
result += f"@{field_name}:{{{value}}}"
else:
value = escaper.escape(value)
result += f"@{field_name}:{{{value}}}"
elif op is Operators.NE:
value = escaper.escape(value)
result += f"-(@{field_name}:{{{value}}})"
elif op is Operators.IN:
# TODO: Implement IN, test this...
expanded_value = cls.expand_tag_value(value)
result += f"(@{field_name}:{{{expanded_value}}})"
elif op is Operators.NOT_IN:
# TODO: Implement NOT_IN, test this...
expanded_value = cls.expand_tag_value(value)
result += f"-(@{field_name}:{{{expanded_value}}})"
@ -314,10 +401,11 @@ class FindQuery:
return ["SORTBY", *fields]
@classmethod
def resolve_redisearch_query(cls, expression: ExpressionOrNegated):
def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
"""Resolve an expression to a string RediSearch query."""
field_type = None
field_name = None
field_info = None
encompassing_expression_is_negated = False
result = ""
@ -328,7 +416,7 @@ class FindQuery:
if expression.op is Operators.ALL:
if encompassing_expression_is_negated:
# TODO: Is there a use case for this, perhaps for dynamic
# scoring purposes?
# scoring purposes?
raise QueryNotSupportedError("You cannot negate a query for all results.")
return "*"
@ -338,6 +426,7 @@ class FindQuery:
elif isinstance(expression.left, ModelField):
field_type = cls.resolve_field_type(expression.left)
field_name = expression.left.name
field_info = expression.left.field_info
else:
raise QueryNotSupportedError(f"A query expression should start with either a field "
f"or an expression enclosed in parenthesis. See docs: "
@ -365,8 +454,7 @@ class FindQuery:
if isinstance(right, ModelField):
raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
else:
# TODO: Optionals causing IDE errors here
result += cls.resolve_value(field_name, field_type, expression.op, right)
result += cls.resolve_value(field_name, field_type, field_info, expression.op, right)
if encompassing_expression_is_negated:
result = f"-({result})"
@ -416,7 +504,10 @@ class FindQuery:
def first(self):
query = FindQuery(expressions=self.expressions, model=self.model,
offset=0, limit=1, sort_fields=self.sort_fields)
return query.execute()[0]
results = query.execute()
if not results:
raise NotFoundError()
return results[0]
def all(self, batch_size=10):
if batch_size != self.page_size:
@ -494,9 +585,13 @@ class PrimaryKeyCreator(Protocol):
"""Create a new primary key"""
class Uuid4PrimaryKey:
def create_pk(self, *args, **kwargs) -> str:
return str(uuid.uuid4())
class UlidPrimaryKey:
"""A client-side generated primary key that follows the ULID spec.
https://github.com/ulid/javascript#specification
"""
@staticmethod
def create_pk(*args, **kwargs) -> str:
return str(ULID())
def __dataclass_transform__(
@ -601,8 +696,24 @@ class PrimaryKey:
field: ModelField
class MetaProtocol(Protocol):
global_key_prefix: str
model_key_prefix: str
primary_key_pattern: str
database: redis.Redis
primary_key: PrimaryKey
primary_key_creator_cls: Type[PrimaryKeyCreator]
index_name: str
abstract: bool
@dataclasses.dataclass
class DefaultMeta:
# TODO: Should this really be optional here?
"""A default placeholder Meta object.
TODO: Revisit whether this is really necessary, and whether making
these all optional here is the right choice.
"""
global_key_prefix: Optional[str] = None
model_key_prefix: Optional[str] = None
primary_key_pattern: Optional[str] = None
@ -614,6 +725,8 @@ class DefaultMeta:
class ModelMeta(ModelMetaclass):
_meta: MetaProtocol
def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
meta = attrs.pop('Meta', None)
new_class = super().__new__(cls, name, bases, attrs, **kwargs)
@ -656,11 +769,16 @@ class ModelMeta(ModelMetaclass):
redis.Redis(decode_responses=True))
if not getattr(new_class._meta, 'primary_key_creator_cls', None):
new_class._meta.primary_key_creator_cls = getattr(base_meta, "primary_key_creator_cls",
Uuid4PrimaryKey)
UlidPrimaryKey)
if not getattr(new_class._meta, 'index_name', None):
new_class._meta.index_name = f"{new_class._meta.global_key_prefix}:" \
f"{new_class._meta.model_key_prefix}:index"
# Not an abstract model class
if abc.ABC not in bases:
key = f"{new_class.__module__}.{new_class.__qualname__}"
model_registry[key] = new_class
return new_class
@ -680,12 +798,8 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
__pydantic_self__.validate_primary_key()
def __lt__(self, other):
"""Default sort: compare all shared model fields."""
my_keys = set(self.__fields__.keys())
other_keys = set(other.__fields__.keys())
shared_keys = list(my_keys & other_keys)
lt = [getattr(self, k) < getattr(other, k) for k in shared_keys]
return len(lt) > len(shared_keys) / 2
"""Default sort: compare primary key of models."""
return self.pk < other.pk
@validator("pk", always=True)
def validate_pk(cls, v):
@ -726,7 +840,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
return cls._meta.database
@classmethod
def find(cls, *expressions: Union[Any, Expression]): # TODO: How to type annotate this?
def find(cls, *expressions: Union[Any, Expression]) -> FindQuery: # TODO: How to type annotate this?
return FindQuery(expressions=expressions, model=cls)
@classmethod
@ -760,7 +874,17 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
except KeyError:
pass
doc = cls(**fields)
try:
fields['json'] = fields['$']
del fields['$']
except KeyError:
pass
if 'json' in fields:
json_fields = json.loads(fields['json'])
doc = cls(**json_fields)
else:
doc = cls(**fields)
docs.append(doc)
return docs
@ -847,7 +971,7 @@ class HashModel(RedisModel, abc.ABC):
_type = field.outer_type_
if getattr(field.field_info, 'primary_key', None):
if issubclass(_type, str):
redisearch_field = f"{name} TAG"
redisearch_field = f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
else:
redisearch_field = cls.schema_for_type(name, _type, field.field_info)
schema_parts.append(redisearch_field)
@ -872,7 +996,7 @@ class HashModel(RedisModel, abc.ABC):
return schema_parts
@classmethod
def schema_for_type(cls, name, typ: Type, field_info: FieldInfo):
def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
if get_origin(typ) == list:
embedded_cls = get_args(typ)
if not embedded_cls:
@ -885,9 +1009,10 @@ class HashModel(RedisModel, abc.ABC):
return f"{name} NUMERIC"
elif issubclass(typ, str):
if getattr(field_info, 'full_text_search', False) is True:
return f"{name} TAG {name}_fts TEXT"
return f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR} " \
f"{name}_fts TEXT"
else:
return f"{name} TAG"
return f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
elif issubclass(typ, RedisModel):
sub_fields = []
for embedded_name, field in typ.__fields__.items():
@ -895,7 +1020,7 @@ class HashModel(RedisModel, abc.ABC):
field.field_info))
return " ".join(sub_fields)
else:
return f"{name} TAG"
return f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
class JsonModel(RedisModel, abc.ABC):
@ -927,7 +1052,7 @@ class JsonModel(RedisModel, abc.ABC):
_type = field.outer_type_
if getattr(field.field_info, 'primary_key', None):
if issubclass(_type, str):
redisearch_field = f"{json_path}.{name} AS {name} TAG"
redisearch_field = f"{json_path}.{name} AS {name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
else:
redisearch_field = cls.schema_for_type(f"{json_path}.{name}", name, "", _type, field.field_info)
schema_parts.append(redisearch_field)
@ -957,8 +1082,8 @@ class JsonModel(RedisModel, abc.ABC):
# find it in the JSON document, AND the name of the field as it should
# be in the redisearch schema (address_address_line_1). Maybe both "name"
# and "name_prefix"?
def schema_for_type(cls, json_path: str, name: str, name_prefix: str, typ: Type,
field_info: FieldInfo) -> str:
def schema_for_type(cls, json_path: str, name: str, name_prefix: str, typ: Any,
field_info: PydanticFieldInfo) -> str:
index_field_name = f"{name_prefix}{name}"
should_index = getattr(field_info, 'index', False)
@ -986,10 +1111,10 @@ class JsonModel(RedisModel, abc.ABC):
return f"{json_path} AS {index_field_name} NUMERIC"
elif issubclass(typ, str):
if getattr(field_info, 'full_text_search', False) is True:
return f"{json_path} AS {index_field_name} TAG " \
return f"{json_path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR} " \
f"{json_path} AS {index_field_name}_fts TEXT"
else:
return f"{json_path} AS {index_field_name} TAG"
return f"{json_path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
else:
return f"{json_path} AS {index_field_name} TAG"

View file

@ -0,0 +1,59 @@
"""
This code adapted from the library "pptree," Copyright (c) 2017 Clément Michard
and released under the MIT license: https://github.com/clemtoy/pptree
"""
import io
def render_tree(current_node, nameattr='name', left_child='left',
right_child='right', indent='', last='updown',
buffer=None):
"""Print a tree-like structure, `current_node`.
This is a mostly-direct-copy of the print_tree() function from the ppbtree
module of the pptree library, but instead of printing to standard out, we
write to a StringIO buffer, then use that buffer to accumulate written lines
during recursive calls to render_tree().
"""
if buffer is None:
buffer = io.StringIO()
if hasattr(current_node, nameattr):
name = lambda node: getattr(node, nameattr)
else:
name = lambda node: str(node)
up = getattr(current_node, left_child, None)
down = getattr(current_node, right_child, None)
if up is not None:
next_last = 'up'
next_indent = '{0}{1}{2}'.format(indent, ' ' if 'up' in last else '|', ' ' * len(str(name(current_node))))
render_tree(up, nameattr, left_child, right_child, next_indent, next_last, buffer)
if last == 'up':
start_shape = ''
elif last == 'down':
start_shape = ''
elif last == 'updown':
start_shape = ' '
else:
start_shape = ''
if up is not None and down is not None:
end_shape = ''
elif up:
end_shape = ''
elif down:
end_shape = ''
else:
end_shape = ''
print('{0}{1}{2}{3}'.format(indent, start_shape, name(current_node), end_shape),
file=buffer)
if down is not None:
next_last = 'down'
next_indent = '{0}{1}{2}'.format(indent, ' ' if 'down' in last else '|', ' ' * len(str(name(current_node))))
render_tree(down, nameattr, left_child, right_child, next_indent, next_last, buffer)
return f"\n{buffer.getvalue()}"

View file

@ -0,0 +1,24 @@
import re
from typing import Optional, Pattern
class TokenEscaper:
"""
Escape punctuation within an input string.
"""
# Characters that RediSearch requires us to escape during queries.
# Source: https://oss.redis.com/redisearch/Escaping/#the_rules_of_text_field_tokenization
DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\ ]"
def __init__(self, escape_chars_re: Optional[Pattern] = None):
if escape_chars_re:
self.escaped_chars_re = escape_chars_re
else:
self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)
def escape(self, string: str) -> str:
def escape_symbol(match):
value = match.group(0)
return f"\\{value}"
return self.escaped_chars_re.sub(escape_symbol, string)