Add support for KNN vector similarity search (#513)
Co-authored-by: Chayim <chayim@users.noreply.github.com>
This commit is contained in:
parent
70f64011fd
commit
89b6c84c4a
3 changed files with 190 additions and 9 deletions
|
@ -8,6 +8,8 @@ from .model.model import (
|
||||||
FindQuery,
|
FindQuery,
|
||||||
HashModel,
|
HashModel,
|
||||||
JsonModel,
|
JsonModel,
|
||||||
|
VectorFieldOptions,
|
||||||
|
KNNExpression,
|
||||||
NotFoundError,
|
NotFoundError,
|
||||||
QueryNotSupportedError,
|
QueryNotSupportedError,
|
||||||
QuerySyntaxError,
|
QuerySyntaxError,
|
||||||
|
|
|
@ -4,6 +4,8 @@ from .model import (
|
||||||
Field,
|
Field,
|
||||||
HashModel,
|
HashModel,
|
||||||
JsonModel,
|
JsonModel,
|
||||||
|
VectorFieldOptions,
|
||||||
|
KNNExpression,
|
||||||
NotFoundError,
|
NotFoundError,
|
||||||
RedisModel,
|
RedisModel,
|
||||||
)
|
)
|
||||||
|
|
|
@ -252,6 +252,24 @@ class Expression:
|
||||||
return render_tree(self)
|
return render_tree(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class KNNExpression:
|
||||||
|
k: int
|
||||||
|
vector_field: ModelField
|
||||||
|
reference_vector: bytes
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"KNN $K @{self.vector_field.name} $knn_ref_vector"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def query_params(self) -> Dict[str, Union[str, bytes]]:
|
||||||
|
return {"K": str(self.k), "knn_ref_vector": self.reference_vector}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def score_field(self) -> str:
|
||||||
|
return f"__{self.vector_field.name}_score"
|
||||||
|
|
||||||
|
|
||||||
ExpressionOrNegated = Union[Expression, NegatedExpression]
|
ExpressionOrNegated = Union[Expression, NegatedExpression]
|
||||||
|
|
||||||
|
|
||||||
|
@ -349,8 +367,9 @@ class FindQuery:
|
||||||
self,
|
self,
|
||||||
expressions: Sequence[ExpressionOrNegated],
|
expressions: Sequence[ExpressionOrNegated],
|
||||||
model: Type["RedisModel"],
|
model: Type["RedisModel"],
|
||||||
|
knn: Optional[KNNExpression] = None,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
limit: int = DEFAULT_PAGE_SIZE,
|
limit: Optional[int] = None,
|
||||||
page_size: int = DEFAULT_PAGE_SIZE,
|
page_size: int = DEFAULT_PAGE_SIZE,
|
||||||
sort_fields: Optional[List[str]] = None,
|
sort_fields: Optional[List[str]] = None,
|
||||||
nocontent: bool = False,
|
nocontent: bool = False,
|
||||||
|
@ -364,13 +383,16 @@ class FindQuery:
|
||||||
|
|
||||||
self.expressions = expressions
|
self.expressions = expressions
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.knn = knn
|
||||||
self.offset = offset
|
self.offset = offset
|
||||||
self.limit = limit
|
self.limit = limit or (self.knn.k if self.knn else DEFAULT_PAGE_SIZE)
|
||||||
self.page_size = page_size
|
self.page_size = page_size
|
||||||
self.nocontent = nocontent
|
self.nocontent = nocontent
|
||||||
|
|
||||||
if sort_fields:
|
if sort_fields:
|
||||||
self.sort_fields = self.validate_sort_fields(sort_fields)
|
self.sort_fields = self.validate_sort_fields(sort_fields)
|
||||||
|
elif self.knn:
|
||||||
|
self.sort_fields = [self.knn.score_field]
|
||||||
else:
|
else:
|
||||||
self.sort_fields = []
|
self.sort_fields = []
|
||||||
|
|
||||||
|
@ -425,11 +447,26 @@ class FindQuery:
|
||||||
if self._query:
|
if self._query:
|
||||||
return self._query
|
return self._query
|
||||||
self._query = self.resolve_redisearch_query(self.expression)
|
self._query = self.resolve_redisearch_query(self.expression)
|
||||||
|
if self.knn:
|
||||||
|
self._query = (
|
||||||
|
self._query
|
||||||
|
if self._query.startswith("(") or self._query == "*"
|
||||||
|
else f"({self._query})"
|
||||||
|
) + f"=>[{self.knn}]"
|
||||||
return self._query
|
return self._query
|
||||||
|
|
||||||
|
@property
|
||||||
|
def query_params(self):
|
||||||
|
params: List[Union[str, bytes]] = []
|
||||||
|
if self.knn:
|
||||||
|
params += [attr for kv in self.knn.query_params.items() for attr in kv]
|
||||||
|
return params
|
||||||
|
|
||||||
def validate_sort_fields(self, sort_fields: List[str]):
|
def validate_sort_fields(self, sort_fields: List[str]):
|
||||||
for sort_field in sort_fields:
|
for sort_field in sort_fields:
|
||||||
field_name = sort_field.lstrip("-")
|
field_name = sort_field.lstrip("-")
|
||||||
|
if self.knn and field_name == self.knn.score_field:
|
||||||
|
continue
|
||||||
if field_name not in self.model.__fields__:
|
if field_name not in self.model.__fields__:
|
||||||
raise QueryNotSupportedError(
|
raise QueryNotSupportedError(
|
||||||
f"You tried sort by {field_name}, but that field "
|
f"You tried sort by {field_name}, but that field "
|
||||||
|
@ -728,10 +765,27 @@ class FindQuery:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def execute(self, exhaust_results=True, return_raw_result=False):
|
async def execute(self, exhaust_results=True, return_raw_result=False):
|
||||||
args = ["ft.search", self.model.Meta.index_name, self.query, *self.pagination]
|
args: List[Union[str, bytes]] = [
|
||||||
|
"FT.SEARCH",
|
||||||
|
self.model.Meta.index_name,
|
||||||
|
self.query,
|
||||||
|
*self.pagination,
|
||||||
|
]
|
||||||
if self.sort_fields:
|
if self.sort_fields:
|
||||||
args += self.resolve_redisearch_sort_fields()
|
args += self.resolve_redisearch_sort_fields()
|
||||||
|
|
||||||
|
if self.query_params:
|
||||||
|
args += ["PARAMS", str(len(self.query_params))] + self.query_params
|
||||||
|
|
||||||
|
if self.knn:
|
||||||
|
# Ensure DIALECT is at least 2
|
||||||
|
if "DIALECT" not in args:
|
||||||
|
args += ["DIALECT", "2"]
|
||||||
|
else:
|
||||||
|
i_dialect = args.index("DIALECT") + 1
|
||||||
|
if int(args[i_dialect]) < 2:
|
||||||
|
args[i_dialect] = "2"
|
||||||
|
|
||||||
if self.nocontent:
|
if self.nocontent:
|
||||||
args.append("NOCONTENT")
|
args.append("NOCONTENT")
|
||||||
|
|
||||||
|
@ -917,11 +971,13 @@ class FieldInfo(PydanticFieldInfo):
|
||||||
sortable = kwargs.pop("sortable", Undefined)
|
sortable = kwargs.pop("sortable", Undefined)
|
||||||
index = kwargs.pop("index", Undefined)
|
index = kwargs.pop("index", Undefined)
|
||||||
full_text_search = kwargs.pop("full_text_search", Undefined)
|
full_text_search = kwargs.pop("full_text_search", Undefined)
|
||||||
|
vector_options = kwargs.pop("vector_options", None)
|
||||||
super().__init__(default=default, **kwargs)
|
super().__init__(default=default, **kwargs)
|
||||||
self.primary_key = primary_key
|
self.primary_key = primary_key
|
||||||
self.sortable = sortable
|
self.sortable = sortable
|
||||||
self.index = index
|
self.index = index
|
||||||
self.full_text_search = full_text_search
|
self.full_text_search = full_text_search
|
||||||
|
self.vector_options = vector_options
|
||||||
|
|
||||||
|
|
||||||
class RelationshipInfo(Representation):
|
class RelationshipInfo(Representation):
|
||||||
|
@ -935,6 +991,94 @@ class RelationshipInfo(Representation):
|
||||||
self.link_model = link_model
|
self.link_model = link_model
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class VectorFieldOptions:
|
||||||
|
class ALGORITHM(Enum):
|
||||||
|
FLAT = "FLAT"
|
||||||
|
HNSW = "HNSW"
|
||||||
|
|
||||||
|
class TYPE(Enum):
|
||||||
|
FLOAT32 = "FLOAT32"
|
||||||
|
FLOAT64 = "FLOAT64"
|
||||||
|
|
||||||
|
class DISTANCE_METRIC(Enum):
|
||||||
|
L2 = "L2"
|
||||||
|
IP = "IP"
|
||||||
|
COSINE = "COSINE"
|
||||||
|
|
||||||
|
algorithm: ALGORITHM
|
||||||
|
type: TYPE
|
||||||
|
dimension: int
|
||||||
|
distance_metric: DISTANCE_METRIC
|
||||||
|
|
||||||
|
# Common optional parameters
|
||||||
|
initial_cap: Optional[int] = None
|
||||||
|
|
||||||
|
# Optional parameters for FLAT
|
||||||
|
block_size: Optional[int] = None
|
||||||
|
|
||||||
|
# Optional parameters for HNSW
|
||||||
|
m: Optional[int] = None
|
||||||
|
ef_construction: Optional[int] = None
|
||||||
|
ef_runtime: Optional[int] = None
|
||||||
|
epsilon: Optional[float] = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def flat(
|
||||||
|
type: TYPE,
|
||||||
|
dimension: int,
|
||||||
|
distance_metric: DISTANCE_METRIC,
|
||||||
|
initial_cap: Optional[int] = None,
|
||||||
|
block_size: Optional[int] = None,
|
||||||
|
):
|
||||||
|
return VectorFieldOptions(
|
||||||
|
algorithm=VectorFieldOptions.ALGORITHM.FLAT,
|
||||||
|
type=type,
|
||||||
|
dimension=dimension,
|
||||||
|
distance_metric=distance_metric,
|
||||||
|
initial_cap=initial_cap,
|
||||||
|
block_size=block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def hnsw(
|
||||||
|
type: TYPE,
|
||||||
|
dimension: int,
|
||||||
|
distance_metric: DISTANCE_METRIC,
|
||||||
|
initial_cap: Optional[int] = None,
|
||||||
|
m: Optional[int] = None,
|
||||||
|
ef_construction: Optional[int] = None,
|
||||||
|
ef_runtime: Optional[int] = None,
|
||||||
|
epsilon: Optional[float] = None,
|
||||||
|
):
|
||||||
|
return VectorFieldOptions(
|
||||||
|
algorithm=VectorFieldOptions.ALGORITHM.HNSW,
|
||||||
|
type=type,
|
||||||
|
dimension=dimension,
|
||||||
|
distance_metric=distance_metric,
|
||||||
|
initial_cap=initial_cap,
|
||||||
|
m=m,
|
||||||
|
ef_construction=ef_construction,
|
||||||
|
ef_runtime=ef_runtime,
|
||||||
|
epsilon=epsilon,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def schema(self):
|
||||||
|
attr = []
|
||||||
|
for k, v in vars(self).items():
|
||||||
|
if k == "algorithm" or v is None:
|
||||||
|
continue
|
||||||
|
attr.extend(
|
||||||
|
[
|
||||||
|
k.upper() if k != "dimension" else "DIM",
|
||||||
|
str(v) if not isinstance(v, Enum) else v.name,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return " ".join([f"VECTOR {self.algorithm.name} {len(attr)}"] + attr)
|
||||||
|
|
||||||
|
|
||||||
def Field(
|
def Field(
|
||||||
default: Any = Undefined,
|
default: Any = Undefined,
|
||||||
*,
|
*,
|
||||||
|
@ -964,6 +1108,7 @@ def Field(
|
||||||
sortable: Union[bool, UndefinedType] = Undefined,
|
sortable: Union[bool, UndefinedType] = Undefined,
|
||||||
index: Union[bool, UndefinedType] = Undefined,
|
index: Union[bool, UndefinedType] = Undefined,
|
||||||
full_text_search: Union[bool, UndefinedType] = Undefined,
|
full_text_search: Union[bool, UndefinedType] = Undefined,
|
||||||
|
vector_options: Optional[VectorFieldOptions] = None,
|
||||||
schema_extra: Optional[Dict[str, Any]] = None,
|
schema_extra: Optional[Dict[str, Any]] = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
current_schema_extra = schema_extra or {}
|
current_schema_extra = schema_extra or {}
|
||||||
|
@ -991,6 +1136,7 @@ def Field(
|
||||||
sortable=sortable,
|
sortable=sortable,
|
||||||
index=index,
|
index=index,
|
||||||
full_text_search=full_text_search,
|
full_text_search=full_text_search,
|
||||||
|
vector_options=vector_options,
|
||||||
**current_schema_extra,
|
**current_schema_extra,
|
||||||
)
|
)
|
||||||
field_info._validate()
|
field_info._validate()
|
||||||
|
@ -1083,6 +1229,10 @@ class ModelMeta(ModelMetaclass):
|
||||||
new_class._meta.primary_key = PrimaryKey(
|
new_class._meta.primary_key = PrimaryKey(
|
||||||
name=field_name, field=field
|
name=field_name, field=field
|
||||||
)
|
)
|
||||||
|
if field.field_info.vector_options:
|
||||||
|
score_attr = f"_{field_name}_score"
|
||||||
|
setattr(new_class, score_attr, None)
|
||||||
|
new_class.__annotations__[score_attr] = Union[float, None]
|
||||||
|
|
||||||
if not getattr(new_class._meta, "global_key_prefix", None):
|
if not getattr(new_class._meta, "global_key_prefix", None):
|
||||||
new_class._meta.global_key_prefix = getattr(
|
new_class._meta.global_key_prefix = getattr(
|
||||||
|
@ -1216,8 +1366,12 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
||||||
return cls._meta.database
|
return cls._meta.database
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def find(cls, *expressions: Union[Any, Expression]) -> FindQuery:
|
def find(
|
||||||
return FindQuery(expressions=expressions, model=cls)
|
cls,
|
||||||
|
*expressions: Union[Any, Expression],
|
||||||
|
knn: Optional[KNNExpression] = None,
|
||||||
|
) -> FindQuery:
|
||||||
|
return FindQuery(expressions=expressions, knn=knn, model=cls)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_redis(cls, res: Any):
|
def from_redis(cls, res: Any):
|
||||||
|
@ -1237,7 +1391,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
||||||
for i in range(1, len(res), step):
|
for i in range(1, len(res), step):
|
||||||
if res[i + offset] is None:
|
if res[i + offset] is None:
|
||||||
continue
|
continue
|
||||||
fields = dict(
|
fields: Dict[str, str] = dict(
|
||||||
zip(
|
zip(
|
||||||
map(to_string, res[i + offset][::2]),
|
map(to_string, res[i + offset][::2]),
|
||||||
map(to_string, res[i + offset][1::2]),
|
map(to_string, res[i + offset][1::2]),
|
||||||
|
@ -1247,6 +1401,9 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
||||||
if fields.get("$"):
|
if fields.get("$"):
|
||||||
json_fields = json.loads(fields.pop("$"))
|
json_fields = json.loads(fields.pop("$"))
|
||||||
doc = cls(**json_fields)
|
doc = cls(**json_fields)
|
||||||
|
for k, v in fields.items():
|
||||||
|
if k.startswith("__") and k.endswith("_score"):
|
||||||
|
setattr(doc, k[1:], float(v))
|
||||||
else:
|
else:
|
||||||
doc = cls(**fields)
|
doc = cls(**fields)
|
||||||
|
|
||||||
|
@ -1474,7 +1631,13 @@ class HashModel(RedisModel, abc.ABC):
|
||||||
embedded_cls = embedded_cls[0]
|
embedded_cls = embedded_cls[0]
|
||||||
schema = cls.schema_for_type(name, embedded_cls, field_info)
|
schema = cls.schema_for_type(name, embedded_cls, field_info)
|
||||||
elif any(issubclass(typ, t) for t in NUMERIC_TYPES):
|
elif any(issubclass(typ, t) for t in NUMERIC_TYPES):
|
||||||
schema = f"{name} NUMERIC"
|
vector_options: Optional[VectorFieldOptions] = getattr(
|
||||||
|
field_info, "vector_options", None
|
||||||
|
)
|
||||||
|
if vector_options:
|
||||||
|
schema = f"{name} {vector_options.schema}"
|
||||||
|
else:
|
||||||
|
schema = f"{name} NUMERIC"
|
||||||
elif issubclass(typ, str):
|
elif issubclass(typ, str):
|
||||||
if getattr(field_info, "full_text_search", False) is True:
|
if getattr(field_info, "full_text_search", False) is True:
|
||||||
schema = (
|
schema = (
|
||||||
|
@ -1623,10 +1786,22 @@ class JsonModel(RedisModel, abc.ABC):
|
||||||
# Not a class, probably a type annotation
|
# Not a class, probably a type annotation
|
||||||
field_is_model = False
|
field_is_model = False
|
||||||
|
|
||||||
|
vector_options: Optional[VectorFieldOptions] = getattr(
|
||||||
|
field_info, "vector_options", None
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
is_vector = vector_options and any(
|
||||||
|
issubclass(get_args(typ)[0], t) for t in NUMERIC_TYPES
|
||||||
|
)
|
||||||
|
except IndexError:
|
||||||
|
raise RedisModelError(
|
||||||
|
f"Vector field '{name}' must be annotated as a container type"
|
||||||
|
)
|
||||||
|
|
||||||
# When we encounter a list or model field, we need to descend
|
# When we encounter a list or model field, we need to descend
|
||||||
# into the values of the list or the fields of the model to
|
# into the values of the list or the fields of the model to
|
||||||
# find any values marked as indexed.
|
# find any values marked as indexed.
|
||||||
if is_container_type:
|
if is_container_type and not is_vector:
|
||||||
field_type = get_origin(typ)
|
field_type = get_origin(typ)
|
||||||
embedded_cls = get_args(typ)
|
embedded_cls = get_args(typ)
|
||||||
if not embedded_cls:
|
if not embedded_cls:
|
||||||
|
@ -1689,7 +1864,9 @@ class JsonModel(RedisModel, abc.ABC):
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: GEO field
|
# TODO: GEO field
|
||||||
if parent_is_container_type or parent_is_model_in_container:
|
if is_vector and vector_options:
|
||||||
|
schema = f"{path} AS {index_field_name} {vector_options.schema}"
|
||||||
|
elif parent_is_container_type or parent_is_model_in_container:
|
||||||
if typ is not str:
|
if typ is not str:
|
||||||
raise RedisModelError(
|
raise RedisModelError(
|
||||||
"In this Preview release, list and tuple fields can only "
|
"In this Preview release, list and tuple fields can only "
|
||||||
|
|
Loading…
Reference in a new issue