Prune

Prune functions and classes

The Prune step optionally removes queries prior to the Update step. This allows control over the total number of queries in scenarios where the Update step generates multiple output queries for each input.

The Prune step is formalized by the PruneFunction schema, which maps inputs List[Query] to outputs List[PruneResponse].

The PruneModule manages execution of a PruneFunction: it gathers valid items, sends them to the PruneFunction, and processes the results.


source

PruneModule

 PruneModule (function:Callable[[List[emb_opt.schemas.Query]], List[emb_opt.schemas.PruneResponse]])

Module - module base class

Given an input Batch, the Module (sketched below):

1. gathers inputs to the function
2. executes the function
3. validates the results of the function with output_schema
4. scatters results back into the Batch
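A simplified, dict-based sketch of that flow is given here. It is an illustration only, not the library's implementation, which operates on Batch, Query and Item objects and validates results against the relevant output schema:

# simplified gather/execute/validate/scatter loop over plain dicts (illustration only)
def run_prune_module(batch_queries, prune_function):
    # 1. gather: collect the queries that have not already been removed
    inputs = [q for q in batch_queries if not q.get('removed', False)]
    # 2. execute: call the prune function on the gathered queries
    responses = prune_function(inputs)
    # 3. validate: expect one PruneResponse-shaped result per input
    assert len(responses) == len(inputs)
    assert all('valid' in r for r in responses)
    # 4. scatter: flag queries whose response was not valid as removed
    for query, response in zip(inputs, responses):
        if not response['valid']:
            query['removed'] = True
    return batch_queries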

Type Details

function   typing.Callable[[typing.List[emb_opt.schemas.Query]], typing.List[emb_opt.schemas.PruneResponse]]   prune function

# assumed imports for illustration: Batch, Query and PruneResponse are schema
# classes in emb_opt.schemas; PruneModule is assumed importable from the prune module
from emb_opt.schemas import Batch, Query, PruneResponse
from emb_opt.prune import PruneModule

batch = Batch(queries=[
                        Query.from_minimal(embedding=[0.1]),
                        Query.from_minimal(embedding=[0.2]),
                        Query.from_minimal(embedding=[0.3]),
                    ])

# mark queries whose first embedding value is below 0.2 as invalid
def prune_func(queries):
    return [PruneResponse(valid=i.embedding[0]>=0.2, data=None) for i in queries]

prune_module = PruneModule(prune_func)

batch = prune_module(batch)

# the invalid query ([0.1]) is flagged as removed
assert [i.internal.removed for i in batch] == [True, False, False]

source

PrunePlugin

 PrunePlugin ()

PrunePlugin - documentation for plugin functions to PruneFunction

A valid PruneFunction is any function that maps List[Query] to List[PruneResponse]. The inputs will be given as Query objects. The outputs can be either a list of PruneResponse objects or a list of valid JSON dictionaries that match the PruneResponse schema.

The Prune step is called after scoring, so each result Item in the input queries will have a score assigned.

Item schema:

{
    'id' : Optional[Union[str, int]],
    'item' : Optional[Any],
    'embedding' : List[float],
    'score' : float,
    'data' : Optional[Dict],
}

Query schema:

{
    'item' : Optional[Any],
    'embedding' : List[float],
    'data' : Optional[Dict],
    'query_results' : List[Item],
}

Input schema:

List[Query]

PruneResponse schema:

{
    'valid' : bool,
    'data' : Optional[Dict],
}

Output schema:

List[PruneResponse]
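
Because each result Item carries a score at this point, a plugin can aggregate those scores and return plain dictionaries that match the PruneResponse schema. The sketch below is an illustration only: threshold_prune and its score_threshold argument are made-up names, and it assumes Query objects expose their scored results through the query_results attribute shown above.

# illustrative plugin-style prune function returning PruneResponse-shaped dicts
def threshold_prune(queries, score_threshold=0.0):
    responses = []
    for query in queries:
        scores = [item.score for item in query.query_results]
        mean_score = sum(scores) / len(scores) if scores else float('-inf')
        responses.append({'valid': mean_score >= score_threshold,
                          'data': {'mean_score': mean_score}})
    return responses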


source

TopKPrune

 TopKPrune (k:int, score_agg:str='mean',
            group_by:Optional[str]='collection_id')

TopKPrune - keeps the top k best queries in each group by aggregated score

Queries are first grouped by group_by:

* if group_by=None, all queries are considered the same group (global pruning)
* if group_by='parent_id', queries are grouped by parent query id
* if group_by='collection_id', queries are grouped by collection id

Queries are then assigned a score by aggregating their query result scores (a simplified sketch of this grouping and aggregation logic follows):

* if score_agg='mean', each Query is scored by the average score of all Item results
* if score_agg='max', each Query is scored by the max scoring Item result
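
To make the selection rule concrete, here is a small, self-contained sketch over plain dictionaries. It illustrates the behaviour only and is not the library's implementation; the 'group' and 'scores' keys and the top_k_valid name are made up:

# toy stand-ins for three queries: a group key plus the scores of their result Items
toy_queries = [
    {'group': 0, 'scores': [-10, 6]},
    {'group': 0, 'scores': [4, 5]},
    {'group': 1, 'scores': [7, 8]},
]

def top_k_valid(queries, k, score_agg='mean'):
    agg = max if score_agg == 'max' else (lambda s: sum(s) / len(s))
    # aggregate each query's result scores into a single query-level score
    scored = [(i, agg(q['scores'])) for i, q in enumerate(queries)]
    valid = [False] * len(queries)
    # within each group, keep only the k highest-scoring queries
    for group in {q['group'] for q in queries}:
        members = [(i, s) for i, s in scored if queries[i]['group'] == group]
        for i, _ in sorted(members, key=lambda x: x[1], reverse=True)[:k]:
            valid[i] = True
    return valid

assert top_k_valid(toy_queries, k=1, score_agg='mean') == [False, True, True]
assert top_k_valid(toy_queries, k=1, score_agg='max') == [True, False, True]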

Type Default Details

k          int
score_agg  str                    mean            ['mean', 'max']
group_by   typing.Optional[str]   collection_id   [None, 'collection_id', 'parent_id']

# assumed imports for illustration: Query and Item are schema classes in
# emb_opt.schemas; TopKPrune is assumed importable from the library's prune module
from emb_opt.schemas import Query, Item
from emb_opt.prune import TopKPrune

q1 = Query.from_minimal(embedding=[0.1])
q1.update_internal(collection_id=0)
q1.add_query_results([
    Item.from_minimal(embedding=[0.11], score=-10),
    Item.from_minimal(embedding=[0.12], score=6),
])

q2 = Query.from_minimal(embedding=[0.2])
q2.update_internal(collection_id=0)
q2.add_query_results([
    Item.from_minimal(embedding=[0.21], score=4),
    Item.from_minimal(embedding=[0.22], score=5),
])

q3 = Query.from_minimal(embedding=[0.3])
q3.update_internal(collection_id=1)
q3.add_query_results([
    Item.from_minimal(embedding=[0.31], score=7),
    Item.from_minimal(embedding=[0.32], score=8),
])

queries = [q1, q2, q3]

# global pruning (group_by=None): mean scores are q1=-2, q2=4.5, q3=7.5,
# so with k=1 only q3 remains valid
prune_func = TopKPrune(k=1, score_agg='mean', group_by=None)

assert [i.valid for i in prune_func(queries)] == [False, False, True]

# per-collection pruning with max aggregation: collection 0 keeps q1 (max 6 > 5),
# collection 1 keeps q3
prune_func = TopKPrune(k=1, score_agg='max', group_by='collection_id')

assert [i.valid for i in prune_func(queries)] == [True, False, True]

# per-collection pruning with mean aggregation: collection 0 keeps q2 (4.5 > -2),
# collection 1 keeps q3
prune_func = TopKPrune(k=1, score_agg='mean', group_by='collection_id')

assert [i.valid for i in prune_func(queries)] == [False, True, True]