# a batch of three minimal queries, distinguished only by their embeddings
batch = Batch(queries=[
    Query.from_minimal(embedding=[0.1]),
    Query.from_minimal(embedding=[0.2]),
    Query.from_minimal(embedding=[0.3]),
])

# prune function: a query is valid only if its first embedding value is >= 0.2
def prune_func(queries):
    return [PruneResponse(valid=i.embedding[0] >= 0.2, data=None) for i in queries]

prune_module = PruneModule(prune_func)
batch = prune_module(batch)

# the first query fails the check, so it is flagged as removed
assert [i.internal.removed for i in batch] == [True, False, False]

Prune
The Prune step optionally removes queries prior to the Update step. This keeps the total number of queries under control when the Update step generates multiple output queries for each input.
The Prune step is formalized by the PruneFunction schema, which maps an input List[Query] to an output List[PruneResponse].
The PruneModule manages execution of a PruneFunction: it gathers valid items, sends them to the PruneFunction, and processes the results.
PruneModule
PruneModule (function:Callable[[List[emb_opt.schemas.Query]], List[emb_opt.schemas.PruneResponse]])
Module - module base class
Given an input Batch, the Module:
1. gathers inputs to the function
2. executes the function
3. validates the results of the function with output_schema
4. scatters results back into the Batch
| | Type | Details |
|---|---|---|
| function | typing.Callable[[typing.List[emb_opt.schemas.Query]], typing.List[emb_opt.schemas.PruneResponse]] | prune function |
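Purely as an illustration of the gather / execute / validate / scatter flow described above (a rough sketch, not the library's actual implementation; the run_prune name and the direct write to query.internal.removed are hypothetical stand-ins for the module's internal bookkeeping):

from emb_opt.schemas import PruneResponse

def run_prune(batch, prune_function):
    # 1. gather: collect the queries to hand to the prune function
    inputs = list(batch)

    # 2. execute: call the user-supplied prune function
    raw_results = prune_function(inputs)

    # 3. validate: coerce results into PruneResponse objects (json dicts are also accepted)
    responses = [r if isinstance(r, PruneResponse) else PruneResponse(**r) for r in raw_results]

    # 4. scatter: record invalid results back onto the corresponding queries
    for query, response in zip(inputs, responses):
        if not response.valid:
            query.internal.removed = True  # hypothetical; the real module manages this state
    return batch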
PrunePlugin
PrunePlugin ()
PrunePlugin - documentation for plugin functions to PruneFunction
A valid PruneFunction is any function that maps List[Query] to List[PruneResponse]. The inputs are given as Query objects. The outputs can be either a list of PruneResponse objects or a list of valid JSON dictionaries matching the PruneResponse schema.
The Prune step is called after scoring, so each result Item in the input queries will have a score assigned.
Item schema:
{ 'id' : Optional[Union[str, int]], 'item' : Optional[Any], 'embedding' : List[float], 'score' : float, 'data' : Optional[Dict], }
Query schema:
{ 'item' : Optional[Any], 'embedding' : List[float], 'data' : Optional[Dict], 'query_results': List[Item] }
Input schema:
List[Query]
PruneResponse schema:
{ 'valid' : bool, 'data' : Optional[Dict], }
Output schema:
List[PruneResponse]
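For example, a minimal plugin-style prune function might keep a query only if its best-scoring result clears a threshold. The function name, the threshold value, and the data payload below are made up for illustration; note that it returns plain dictionaries matching the PruneResponse schema rather than PruneResponse objects:

def score_threshold_prune(queries):
    threshold = 0.5  # illustrative value
    responses = []
    for query in queries:
        # every Item in query_results has already been scored at this point
        best = max((item.score for item in query.query_results), default=None)
        keep = best is not None and best >= threshold
        # a plain dict matching the PruneResponse schema is a valid return value
        responses.append({'valid': keep, 'data': {'best_score': best}})
    return responses

Wrapped in a PruneModule, this function can be used like any other PruneFunction.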
TopKPrune
TopKPrune (k:int, score_agg:str='mean', group_by:Optional[str]='collection_id')
TopKPrune - keeps the top k best queries in each group by aggregated score
Queries are first grouped by group_by:
* if group_by=None, all queries are considered the same group (global pruning)
* if group_by='parent_id', queries are grouped by parent query id
* if group_by='collection_id', queries are grouped by collection id

Queries are then assigned a score by aggregating query result scores:
* if score_agg='mean', each Query is scored by the average score of all Item results
* if score_agg='max', each Query is scored by the max scoring Item result
| | Type | Default | Details |
|---|---|---|---|
| k | int | | |
| score_agg | str | mean | ['mean', 'max'] |
| group_by | typing.Optional[str] | collection_id | [None, 'collection_id', 'parent_id'] |
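The selection logic amounts to grouping, aggregating, and ranking. The sketch below illustrates that idea and is not the library's implementation (it assumes the group key can be read from query.internal, as the examples below suggest, and returns PruneResponse-shaped dicts):

from collections import defaultdict
import numpy as np

def top_k_by_group(queries, k, score_agg='mean', group_by='collection_id'):
    # aggregate each query's result scores into a single value
    agg = np.mean if score_agg == 'mean' else np.max
    scores = [agg([item.score for item in q.query_results]) for q in queries]

    # group query indices by the chosen key (group_by=None means one global group)
    groups = defaultdict(list)
    for idx, q in enumerate(queries):
        groups[getattr(q.internal, group_by) if group_by else None].append(idx)

    # within each group, keep only the k highest-scoring queries
    valid = [False] * len(queries)
    for members in groups.values():
        for idx in sorted(members, key=lambda i: scores[i], reverse=True)[:k]:
            valid[idx] = True

    # emit PruneResponse-shaped dicts, one per input query
    return [{'valid': v, 'data': None} for v in valid]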
q1 = Query.from_minimal(embedding=[0.1])
q1.update_internal(collection_id=0)
q1.add_query_results([
Item.from_minimal(embedding=[0.11], score=-10),
Item.from_minimal(embedding=[0.12], score=6),
])
q2 = Query.from_minimal(embedding=[0.2])
q2.update_internal(collection_id=0)
q2.add_query_results([
Item.from_minimal(embedding=[0.21], score=4),
Item.from_minimal(embedding=[0.22], score=5),
])
q3 = Query.from_minimal(embedding=[0.3])
q3.update_internal(collection_id=1)
q3.add_query_results([
Item.from_minimal(embedding=[0.31], score=7),
Item.from_minimal(embedding=[0.32], score=8),
])
queries = [q1, q2, q3]
# global pruning: with group_by=None all queries form one group, and k=1 keeps only the
# highest mean-scoring query (q3, mean score 7.5)
prune_func = TopKPrune(k=1, score_agg='mean', group_by=None)
assert [i.valid for i in prune_func(queries)] == [False, False, True]

# rebuild the same queries for the grouped examples
q1 = Query.from_minimal(embedding=[0.1])
q1.update_internal(collection_id=0)
q1.add_query_results([
Item.from_minimal(embedding=[0.11], score=-10),
Item.from_minimal(embedding=[0.12], score=6),
])
q2 = Query.from_minimal(embedding=[0.2])
q2.update_internal(collection_id=0)
q2.add_query_results([
Item.from_minimal(embedding=[0.21], score=4),
Item.from_minimal(embedding=[0.22], score=5),
])
q3 = Query.from_minimal(embedding=[0.3])
q3.update_internal(collection_id=1)
q3.add_query_results([
Item.from_minimal(embedding=[0.31], score=7),
Item.from_minimal(embedding=[0.32], score=8),
])
queries = [q1, q2, q3]
# group by collection: q1 (max score 6) beats q2 (max score 5) in collection 0; q3 is alone in collection 1
prune_func = TopKPrune(k=1, score_agg='max', group_by='collection_id')
assert [i.valid for i in prune_func(queries)] == [True, False, True]

# with mean aggregation, q2 (mean 4.5) beats q1 (mean -2.0) in collection 0
prune_func = TopKPrune(k=1, score_agg='mean', group_by='collection_id')
assert [i.valid for i in prune_func(queries)] == [False, True, True]