@dataclasses.dataclass
class KNNExpression:
    """A K-nearest-neighbours clause for a vector similarity query.

    The string form of an instance is the ``KNN ...`` clause itself; it
    refers to its inputs through the ``$K`` and ``$knn_ref_vector``
    placeholders, whose values are supplied by ``query_params``.
    """

    # Number of neighbours requested; sent to the server as the ``K`` param.
    k: int
    # Model field holding the indexed vector; only its ``name`` is read here.
    vector_field: ModelField
    # Raw bytes of the vector to compare against.
    reference_vector: bytes

    def __str__(self):
        """Render the KNN clause with parameter placeholders."""
        return f"KNN $K @{self.vector_field.name} $knn_ref_vector"

    @property
    def query_params(self) -> Dict[str, Union[str, bytes]]:
        """Map each placeholder used by ``__str__`` to its value."""
        bindings: Dict[str, Union[str, bytes]] = dict(
            K=str(self.k),
            knn_ref_vector=self.reference_vector,
        )
        return bindings

    @property
    def score_field(self) -> str:
        """Name of the result field associated with this vector's score."""
        return f"__{self.vector_field.name}_score"
@property
def query_params(self):
    """Return the query parameters as a flat ``[name, value, ...]`` list.

    The list is empty when the query has no KNN expression; otherwise
    each name in the expression's ``query_params`` mapping is followed
    immediately by its value, in mapping order.
    """
    if not self.knn:
        return []

    flattened: List[Union[str, bytes]] = []
    for param_name, param_value in self.knn.query_params.items():
        flattened.extend((param_name, param_value))
    return flattened
@dataclasses.dataclass
class VectorFieldOptions:
    """Index options for a vector field, rendered by ``schema``.

    Build instances with ``flat()`` or ``hnsw()`` so that only the options
    belonging to the chosen algorithm can be set. Options left as ``None``
    are omitted from the rendered schema.
    """

    class ALGORITHM(Enum):
        FLAT = "FLAT"
        HNSW = "HNSW"

    class TYPE(Enum):
        FLOAT32 = "FLOAT32"
        FLOAT64 = "FLOAT64"

    class DISTANCE_METRIC(Enum):
        L2 = "L2"
        IP = "IP"
        COSINE = "COSINE"

    algorithm: ALGORITHM
    type: TYPE
    dimension: int
    distance_metric: DISTANCE_METRIC

    # Common optional parameters
    initial_cap: Optional[int] = None

    # Optional parameters for FLAT
    block_size: Optional[int] = None

    # Optional parameters for HNSW
    m: Optional[int] = None
    ef_construction: Optional[int] = None
    ef_runtime: Optional[int] = None
    epsilon: Optional[float] = None

    # The factories are classmethods (not staticmethods) so that calling
    # them on a subclass returns an instance of that subclass.
    @classmethod
    def flat(
        cls,
        type: TYPE,
        dimension: int,
        distance_metric: DISTANCE_METRIC,
        initial_cap: Optional[int] = None,
        block_size: Optional[int] = None,
    ):
        """Create options for the FLAT algorithm.

        Only the common options and the FLAT-specific ``block_size`` can
        be set through this constructor.
        """
        return cls(
            algorithm=cls.ALGORITHM.FLAT,
            type=type,
            dimension=dimension,
            distance_metric=distance_metric,
            initial_cap=initial_cap,
            block_size=block_size,
        )

    @classmethod
    def hnsw(
        cls,
        type: TYPE,
        dimension: int,
        distance_metric: DISTANCE_METRIC,
        initial_cap: Optional[int] = None,
        m: Optional[int] = None,
        ef_construction: Optional[int] = None,
        ef_runtime: Optional[int] = None,
        epsilon: Optional[float] = None,
    ):
        """Create options for the HNSW algorithm.

        Only the common options and the HNSW-specific ``m``,
        ``ef_construction``, ``ef_runtime`` and ``epsilon`` can be set
        through this constructor.
        """
        return cls(
            algorithm=cls.ALGORITHM.HNSW,
            type=type,
            dimension=dimension,
            distance_metric=distance_metric,
            initial_cap=initial_cap,
            m=m,
            ef_construction=ef_construction,
            ef_runtime=ef_runtime,
            epsilon=epsilon,
        )

    @property
    def schema(self):
        """Render the options as ``VECTOR <algorithm> <count> <name> <value> ...``.

        ``<count>`` is the number of tokens that follow it. Option names
        are the upper-cased field names, except ``dimension`` which is
        emitted as ``DIM``. Enum values are rendered by member name and
        everything else with ``str()``.
        """
        attr = []
        # Walk the declared dataclass fields rather than ``vars(self)``:
        # an attribute attached to the instance after construction must not
        # be emitted as an index option or counted in ``<count>``.
        for field in dataclasses.fields(self):
            value = getattr(self, field.name)
            if field.name == "algorithm" or value is None:
                continue
            attr.extend(
                [
                    "DIM" if field.name == "dimension" else field.name.upper(),
                    value.name if isinstance(value, Enum) else str(value),
                ]
            )

        return " ".join([f"VECTOR {self.algorithm.name} {len(attr)}"] + attr)
@@ -1237,7 +1391,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): for i in range(1, len(res), step): if res[i + offset] is None: continue - fields = dict( + fields: Dict[str, str] = dict( zip( map(to_string, res[i + offset][::2]), map(to_string, res[i + offset][1::2]), @@ -1247,6 +1401,9 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): if fields.get("$"): json_fields = json.loads(fields.pop("$")) doc = cls(**json_fields) + for k, v in fields.items(): + if k.startswith("__") and k.endswith("_score"): + setattr(doc, k[1:], float(v)) else: doc = cls(**fields) @@ -1474,7 +1631,13 @@ class HashModel(RedisModel, abc.ABC): embedded_cls = embedded_cls[0] schema = cls.schema_for_type(name, embedded_cls, field_info) elif any(issubclass(typ, t) for t in NUMERIC_TYPES): - schema = f"{name} NUMERIC" + vector_options: Optional[VectorFieldOptions] = getattr( + field_info, "vector_options", None + ) + if vector_options: + schema = f"{name} {vector_options.schema}" + else: + schema = f"{name} NUMERIC" elif issubclass(typ, str): if getattr(field_info, "full_text_search", False) is True: schema = ( @@ -1623,10 +1786,22 @@ class JsonModel(RedisModel, abc.ABC): # Not a class, probably a type annotation field_is_model = False + vector_options: Optional[VectorFieldOptions] = getattr( + field_info, "vector_options", None + ) + try: + is_vector = vector_options and any( + issubclass(get_args(typ)[0], t) for t in NUMERIC_TYPES + ) + except IndexError: + raise RedisModelError( + f"Vector field '{name}' must be annotated as a container type" + ) + # When we encounter a list or model field, we need to descend # into the values of the list or the fields of the model to # find any values marked as indexed. 
- if is_container_type: + if is_container_type and not is_vector: field_type = get_origin(typ) embedded_cls = get_args(typ) if not embedded_cls: @@ -1689,7 +1864,9 @@ class JsonModel(RedisModel, abc.ABC): ) # TODO: GEO field - if parent_is_container_type or parent_is_model_in_container: + if is_vector and vector_options: + schema = f"{path} AS {index_field_name} {vector_options.schema}" + elif parent_is_container_type or parent_is_model_in_container: if typ is not str: raise RedisModelError( "In this Preview release, list and tuple fields can only "