Update

Update functions and classes

The Update step uses a set of queries and scored results to generate a new set of queries for the next iteration of the search.

Updates are categorized as discrete or continuous. Continuous updates generate new query embeddings purely in embedding space (i.e. by averaging several embeddings); as a result, continuous update outputs have no specific item associated with them. Discrete updates use a specific query result Item as the new query, preserving the association with that item.
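To make the distinction concrete, here is a minimal standalone sketch (plain Python and numpy, not the emb_opt API) contrasting the two update styles for one query's scored results:

```python
import numpy as np

# Illustrative sketch only: scored result embeddings for a single query.
results = [
    {"item": "a", "embedding": np.array([0.1, 0.2]), "score": 1.0},
    {"item": "b", "embedding": np.array([0.3, 0.4]), "score": 3.0},
]

# Continuous: the new query embedding is an average of result embeddings,
# so no existing item corresponds to it.
continuous_query = np.mean([r["embedding"] for r in results], axis=0)

# Discrete: the new query is the best-scoring result itself,
# so the item association is preserved.
discrete_query = max(results, key=lambda r: r["score"])
```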

The update step is formalized by the UpdateFunction schema, which maps inputs List[Query] to outputs List[UpdateResponse]. Note that the number of outputs can differ from the number of inputs.

The UpdateModule manages execution of an UpdateFunction: it gathers valid items, sends them to the UpdateFunction, and processes the results.


source

UpdateModule

 UpdateModule (function:Callable[[List[emb_opt.schemas.Query]],List[emb_op
               t.schemas.UpdateResponse]])

Module - module base class

Given an input Batch, the Module:

1. gathers inputs to the function
2. executes the function
3. validates the results of the function with output_schema
4. scatters results back into the Batch

def passthrough_update_test(queries):
    return [UpdateResponse(query=i, parent_id=None) for i in queries]

batch = Batch(queries=[
                        Query.from_minimal(embedding=[0.1]),
                        Query.from_minimal(embedding=[0.2]),
                        Query.from_minimal(embedding=[0.3]),
                    ])

update_module = UpdateModule(passthrough_update_test)

batch = update_module(batch)

assert isinstance(batch, Batch)
assert isinstance(batch[0], Query)

def continuous_update_test(queries):
    outputs = []
    for query in queries:
        new_query = Query.from_minimal(embedding=[i*2 for i in query.embedding])
        outputs.append(UpdateResponse(query=new_query, parent_id=query.id))
    return outputs

batch = Batch(queries=[
                        Query.from_minimal(embedding=[0.1]),
                        Query.from_minimal(embedding=[0.2]),
                        Query.from_minimal(embedding=[0.3]),
                    ])

for i in range(len(batch)):
    batch.queries[i].update_internal(collection_id=i)

update_module = UpdateModule(continuous_update_test)

batch2 = update_module(batch)

assert all([batch2[i].internal.collection_id==batch[i].internal.collection_id for i in range(len(batch2))])

assert isinstance(batch2, Batch)
assert isinstance(batch2[0], Query)

def discrete_update_test(queries):
    return [UpdateResponse(query=Query.from_item(i[0]), parent_id=i.id) for i in queries]

queries = []
for i in range(3):
    q = Query.from_minimal(embedding=[i*0.1])
    q.update_internal(collection_id=i)
    r = Item.from_minimal(embedding=[i*2*0.1])
    q.add_query_results([r])
    queries.append(q)
    
batch = Batch(queries=queries)

update_module = UpdateModule(discrete_update_test)

batch2 = update_module(batch)

assert isinstance(batch2, Batch)
assert isinstance(batch2[0], Query)

for i in range(len(batch2)):
    assert batch2[i].internal.parent_id == batch[i][0].internal.parent_id
    assert batch2[i].data['_source_item_id'] == batch[i][0].id
    assert batch2[i].internal.collection_id == batch[i].internal.collection_id

source

UpdatePlugin

 UpdatePlugin ()

UpdatePlugin - documentation for plugin functions to UpdateFunction

A valid UpdateFunction is any function that maps List[Query] to List[UpdateResponse]. The inputs will be given as Query objects. The outputs can be either a list of UpdateResponse objects or a list of valid JSON dictionaries that match the UpdateResponse schema. The number of outputs can differ from the number of inputs.

Item schema:

{
    'id' : Optional[Union[str, int]],
    'item' : Optional[Any],
    'embedding' : List[float],
    'score' : float,
    'data' : Optional[Dict],
}

Query schema:

{
    'item' : Optional[Any],
    'embedding' : List[float],
    'data' : Optional[Dict],
    'query_results' : List[Item],
}

UpdateResponse schema:

{
    'query' : Query,
    'parent_id' : Optional[str],
}

Input schema:

List[Query]

Output schema:

List[UpdateResponse]
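As a hypothetical sketch of the dict-returning style (plain dicts stand in for Query objects here purely for illustration; the function name and field values are not part of the library):

```python
# Illustrative plugin sketch: a valid UpdateFunction may return plain
# dictionaries matching the UpdateResponse schema instead of
# UpdateResponse objects.
def dict_update(queries):
    responses = []
    for q in queries:
        # Build a new query dict matching the Query schema.
        new_query = {
            "item": None,
            "embedding": [x * 0.5 for x in q["embedding"]],
            "data": None,
            "query_results": [],
        }
        # parent_id links the new query back to the query it came from.
        responses.append({"query": new_query, "parent_id": q.get("id")})
    return responses

queries = [{"id": "q0", "embedding": [0.2, 0.4]}]
responses = dict_update(queries)
```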


source

UpdatePluginGradientWrapper

 UpdatePluginGradientWrapper (function:Callable[[List[emb_opt.schemas.Quer
                              y]],List[emb_opt.schemas.UpdateResponse]],
                              distance_penalty:float=0,
                              max_norm:Optional[float]=None,
                              norm_type:Union[float,int,str,NoneType]=2.0)

UpdatePluginGradientWrapper - this class wraps a valid UpdateFunction to estimate the gradient of new queries using the results and scores computed for the parent query.

This wrapper integrates with DataPluginGradWrapper, allowing new query vectors to be created based on the gradient.

function (Callable[[List[Query]], List[UpdateResponse]]): UpdateFunction to wrap
distance_penalty (float, default 0): RL grad distance penalty
max_norm (Optional[float], default None): max grad norm
norm_type (Union[float, int, str, None], default 2.0): grad norm type

source

TopKDiscreteUpdate

 TopKDiscreteUpdate (k:int)

TopKDiscreteUpdate - discrete update that generates k new queries from the top k scoring items in each input query

k (int): top k items to return as new queries

q1 = Query.from_minimal(embedding=[0.1])
q1.add_query_results([
    Item(id=None, item='1', embedding=[0.11], score=-10, data=None),
    Item(id=None, item='2', embedding=[0.12], score=6, data=None),
    Item(id=None, item='3', embedding=[0.12], score=1, data=None),
])

q2 = Query.from_minimal(embedding=[0.2])
q2.add_query_results([
    Item(id=None, item='4', embedding=[0.21], score=4, data=None),
    Item(id=None, item='5', embedding=[0.22], score=5, data=None),
    Item(id=None, item='6', embedding=[0.12], score=2, data=None),
])

batch = Batch(queries=[q1, q2])

update_func = TopKDiscreteUpdate(k=2)
update_module = UpdateModule(update_func)
batch2 = update_module(batch)

assert [i.item for i in batch2] == ['2', '3', '5', '4']

update_func2 = UpdatePluginGradientWrapper(update_func)
update_module2 = UpdateModule(update_func2)
batch3 = update_module2(batch)

q1.query_results[1].internal.removed = True
batch = Batch(queries=[q1, q2])
batch2 = update_module(batch)
assert [i.item for i in batch2] == ['3', '1', '5', '4']

source

TopKContinuousUpdate

 TopKContinuousUpdate (k:int)

TopKContinuousUpdate - continuous update that generates 1 new query by averaging the top k scoring item embeddings for each input query

k (int): top k items to average

q1 = Query.from_minimal(embedding=[0.1])
q1.add_query_results([
    Item(id=None, item='1', embedding=[0.1], score=-10, data=None),
    Item(id=None, item='2', embedding=[0.2], score=6, data=None),
])

q2 = Query.from_minimal(embedding=[0.2])
q2.add_query_results([
    Item(id=None, item='4', embedding=[0.2], score=4, data=None),
    Item(id=None, item='5', embedding=[0.3], score=5, data=None),
])

batch = Batch(queries=[q1, q2])

update_func = TopKContinuousUpdate(k=2)
update_module = UpdateModule(update_func)
batch2 = update_module(batch)

assert np.allclose([i.embedding for i in batch2], [[0.15], [0.25]])

update_func = TopKContinuousUpdate(k=1)
update_module = UpdateModule(update_func)
batch2 = update_module(batch)

assert np.allclose([i.embedding for i in batch2], [[0.2], [0.3]])

update_func2 = UpdatePluginGradientWrapper(update_func)
update_module2 = UpdateModule(update_func2)
batch3 = update_module2(batch)

source

RLUpdate

 RLUpdate (lrs:Union[List[float],numpy.ndarray], distance_penalty:float,
           max_norm:Optional[float]=None,
           norm_type:Union[float,int,str,NoneType]=2.0)

RLUpdate - uses reinforcement learning to update queries

To compute the gradient with RL:

1. compute advantages by whitening scores:
   advantage[i] = (scores[i] - scores.mean()) / scores.std()
2. compute the advantage loss:
   advantage_loss[i] = advantage[i] * (query_embedding - result_embedding[i])**2
3. compute the distance loss:
   distance_loss[i] = distance_penalty * (query_embedding - result_embedding[i])**2
4. sum the loss terms:
   loss[i] = advantage_loss[i] + distance_loss[i]
5. compute the gradient

This gives a closed-form calculation of the gradient:

grad[i] = 2 * (advantage[i] + distance_penalty) * (query_embedding - result_embedding[i])
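As a standalone numpy sketch of this closed-form gradient (illustrative only, not the RLUpdate internals; averaging over results and the single learning-rate step below are assumptions for the example):

```python
import numpy as np

def rl_gradient(query_emb, result_embs, scores, distance_penalty=0.1):
    # Whiten scores into advantages (population std, numpy default).
    advantages = (scores - scores.mean()) / scores.std()
    diffs = query_emb - result_embs  # shape (n_results, dim)
    # grad[i] = 2 * (advantage[i] + distance_penalty) * (query_emb - result_emb[i])
    grads = 2 * (advantages[:, None] + distance_penalty) * diffs
    # Averaging over results is a choice made for this sketch.
    return grads.mean(axis=0)

query_emb = np.array([0.1, 0.2])
result_embs = np.array([[0.2, 0.4], [0.0, 0.1]])
scores = np.array([1.5, 0.5])
grad = rl_gradient(query_emb, result_embs, scores)

# One gradient-descent step at a single learning rate.
new_query = query_emb - 1e-1 * grad
```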

lrs (Union[List[float], np.ndarray]): list of learning rates
distance_penalty (float): distance penalty coefficient
max_norm (Optional[float], default None): optional max grad norm for clipping
norm_type (Union[float, int, str, None], default 2.0): norm type

lrs = np.array([1e-2, 1e-1, 1e0, 1e1])
dp = 0.1

update_function = RLUpdate(lrs, dp, max_norm=1.)

update_module = UpdateModule(update_function)

queries = []
for i in range(1,4):
    q = Query.from_minimal(embedding=[i*0.1, i*0.2, i*0.3, 0.1, 0.1])
    q.update_internal(collection_id=i)
    r1 = Item.from_minimal(embedding=[2*i*0.1, 2*i*0.2, 2*i*0.3, -0.1, -0.1], score=i*1.5)
    r2 = Item.from_minimal(embedding=[-2*i*0.1, 2*i*0.2, -2*i*0.3, -0.1, -0.1], score=i*.5)
    q.add_query_results([r1, r2])
    queries.append(q)
    
batch = Batch(queries=queries)

batch2 = update_module(batch)

assert len(batch2)/len(batch) == len(lrs)