def passthrough_update_test(queries):
    return [UpdateResponse(query=i, parent_id=None) for i in queries]
batch = Batch(queries=[
    Query.from_minimal(embedding=[0.1]),
    Query.from_minimal(embedding=[0.2]),
    Query.from_minimal(embedding=[0.3]),
])
update_module = UpdateModule(passthrough_update_test)
batch = update_module(batch)
assert isinstance(batch, Batch)
assert isinstance(batch[0], Query)
Update
The Update step uses a set of queries and scored results to generate a new set of queries for the next iteration of the search.
Updates are denoted as discrete or continuous. Continuous updates generate new query embeddings purely in embedding space (i.e. by averaging several embeddings); as a result, continuous update outputs do not have a specific item associated with them. Discrete updates use a specific query result Item as the new query, keeping the item associated with it.
The update step is formalized by the UpdateFunction schema, which maps inputs of type List[Query] to outputs of type List[UpdateResponse]. Note that the number of outputs can differ from the number of inputs.
The UpdateModule manages execution of an UpdateFunction. The UpdateModule gathers the valid inputs, sends them to the UpdateFunction, and processes the results.
UpdateModule
UpdateModule (function:Callable[[List[emb_opt.schemas.Query]], List[emb_opt.schemas.UpdateResponse]])
Module - module base class
Given an input Batch, the Module:
1. gathers inputs to the function
2. executes the function
3. validates the results of the function with output_schema
4. scatters results back into the Batch
def continuous_update_test(queries):
    # continuous update: build a new embedding for each query (here simply scaled by 2)
    outputs = []
    for query in queries:
        new_query = Query.from_minimal(embedding=[i*2 for i in query.embedding])
        outputs.append(UpdateResponse(query=new_query, parent_id=query.id))
    return outputs
batch = Batch(queries=[
    Query.from_minimal(embedding=[0.1]),
    Query.from_minimal(embedding=[0.2]),
    Query.from_minimal(embedding=[0.3]),
])
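# assign collection ids so the test can check they are preserved through the update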
[batch.queries[i].update_internal(collection_id=i) for i in range(len(batch))]
update_module = UpdateModule(continuous_update_test)
batch2 = update_module(batch)
assert all([batch2[i].internal.collection_id==batch[i].internal.collection_id for i in range(len(batch2))])
assert isinstance(batch2, Batch)
assert isinstance(batch2[0], Query)
def discrete_update_test(queries):
    # discrete update: promote each query's first result item to a new query
    return [UpdateResponse(query=Query.from_item(i[0]), parent_id=i.id) for i in queries]
queries = []
for i in range(3):
    q = Query.from_minimal(embedding=[i*0.1])
    q.update_internal(collection_id=i)
    r = Item.from_minimal(embedding=[i*2*0.1])
    q.add_query_results([r])
    queries.append(q)
batch = Batch(queries=queries)
update_module = UpdateModule(discrete_update_test)
batch2 = update_module(batch)
assert isinstance(batch2, Batch)
assert isinstance(batch2[0], Query)
for i in range(len(batch2)):
    assert batch2[i].internal.parent_id == batch[i][0].internal.parent_id
    assert batch2[i].data['_source_item_id'] == batch[i][0].id
    assert batch2[i].internal.collection_id == batch[i].internal.collection_id
UpdatePlugin
UpdatePlugin ()
UpdatePlugin - documentation for plugin functions to UpdateFunction
A valid UpdateFunction is any function that maps List[Query] to List[UpdateResponse]. The inputs will be given as Query objects. The outputs can be either a list of UpdateResponse objects or a list of valid json dictionaries that match the UpdateResponse schema (see the sketch after the schemas below). The number of outputs can be different from the number of inputs.
Item schema:
{
    'id' : Optional[Union[str, int]],
    'item' : Optional[Any],
    'embedding' : List[float],
    'score' : float,
    'data' : Optional[Dict],
}
Query schema:
{
    'item' : Optional[Any],
    'embedding' : List[float],
    'data' : Optional[Dict],
    'query_results' : List[Item],
}
UpdateResponse schema:
{
    'query' : Query,
    'parent_id' : Optional[str],
}
Input schema:
List[Query]
Output schema:
List[UpdateResponse]
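As a minimal sketch of a plugin function (dict_update_plugin is an illustrative name; this assumes, as described above, that plain dictionaries matching the UpdateResponse schema are accepted in place of UpdateResponse objects):
def dict_update_plugin(queries):
    # illustrative plugin: returns plain dicts that follow the UpdateResponse schema above
    outputs = []
    for query in queries:
        outputs.append({
            'query': {
                'item': None,
                'embedding': [v * 0.5 for v in query.embedding],
                'data': None,
                'query_results': [],
            },
            'parent_id': query.id,
        })
    return outputs
update_module = UpdateModule(dict_update_plugin)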
UpdatePluginGradientWrapper
UpdatePluginGradientWrapper (function:Callable[[List[emb_opt.schemas.Query]], List[emb_opt.schemas.UpdateResponse]], distance_penalty:float=0, max_norm:Optional[float]=None, norm_type:Union[float,int,str,NoneType]=2.0)
UpdatePluginGradientWrapper - this class wraps a valid UpdateFunction to estimate the gradient of new queries using the results and scores computed for the parent query.
This wrapper integrates with DataPluginGradWrapper, which allows us to create new query vectors based on the gradient.
| | Type | Default | Details |
|---|---|---|---|
| function | typing.Callable[[typing.List[emb_opt.schemas.Query]], typing.List[emb_opt.schemas.UpdateResponse]] | | UpdateFunction to wrap |
| distance_penalty | float | 0 | RL grad distance penalty |
| max_norm | typing.Optional[float] | None | max grad norm |
| norm_type | typing.Union[float, int, str, NoneType] | 2.0 | grad norm type |
TopKDiscreteUpdate
TopKDiscreteUpdate (k:int)
TopKDiscreteUpdate - discrete update that generates k new queries from the top k scoring items in each input query
| | Type | Details |
|---|---|---|
| k | int | top k items to return as new queries |
q1 = Query.from_minimal(embedding=[0.1])
q1.add_query_results([
    Item(id=None, item='1', embedding=[0.11], score=-10, data=None),
    Item(id=None, item='2', embedding=[0.12], score=6, data=None),
    Item(id=None, item='3', embedding=[0.12], score=1, data=None),
])
q2 = Query.from_minimal(embedding=[0.2])
q2.add_query_results([
    Item(id=None, item='4', embedding=[0.21], score=4, data=None),
    Item(id=None, item='5', embedding=[0.22], score=5, data=None),
    Item(id=None, item='6', embedding=[0.12], score=2, data=None),
])
batch = Batch(queries=[q1, q2])
update_func = TopKDiscreteUpdate(k=2)
update_module = UpdateModule(update_func)
batch2 = update_module(batch)
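# each input query contributes its top 2 items, ordered by descending score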
assert [i.item for i in batch2] == ['2', '3', '5', '4']
update_func2 = UpdatePluginGradientWrapper(update_func)
update_module2 = UpdateModule(update_func2)
batch3 = update_module2(batch)
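# items flagged as removed are excluded from the top-k selection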
q1.query_results[1].internal.removed = True
batch = Batch(queries=[q1, q2])
batch2 = update_module(batch)
assert [i.item for i in batch2] == ['3', '1', '5', '4']
TopKContinuousUpdate
TopKContinuousUpdate (k:int)
TopKContinuousUpdate - continuous update that generates 1 new query by averaging the top k scoring item embeddings for each input query
| | Type | Details |
|---|---|---|
| k | int | top k items to average |
q1 = Query.from_minimal(embedding=[0.1])
q1.add_query_results([
    Item(id=None, item='1', embedding=[0.1], score=-10, data=None),
    Item(id=None, item='2', embedding=[0.2], score=6, data=None),
])
q2 = Query.from_minimal(embedding=[0.2])
q2.add_query_results([
    Item(id=None, item='4', embedding=[0.2], score=4, data=None),
    Item(id=None, item='5', embedding=[0.3], score=5, data=None),
])
batch = Batch(queries=[q1, q2])
update_func = TopKContinuousUpdate(k=2)
update_module = UpdateModule(update_func)
batch2 = update_module(batch)
assert np.allclose([i.embedding for i in batch2], [[0.15], [0.25]])
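# with k=1, each new query embedding is just the top-scoring item's embedding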
update_func = TopKContinuousUpdate(k=1)
update_module = UpdateModule(update_func)
batch2 = update_module(batch)
assert np.allclose([i.embedding for i in batch2], [[0.2], [0.3]])
update_func2 = UpdatePluginGradientWrapper(update_func)
update_module2 = UpdateModule(update_func2)
batch3 = update_module2(batch)
RLUpdate
RLUpdate (lrs:Union[List[float],numpy.ndarray], distance_penalty:float, max_norm:Optional[float]=None, norm_type:Union[float,int,str,NoneType]=2.0)
RLUpdate - uses reinforcement learning to update queries
To compute the gradient with RL:
1. compute advantages by whitening scores
    1. advantage[i] = (scores[i] - scores.mean()) / scores.std()
2. compute the advantage loss
    1. advantage_loss[i] = advantage[i] * (query_embedding - result_embedding[i])**2
3. compute the distance loss
    1. distance_loss[i] = distance_penalty * (query_embedding - result_embedding[i])**2
4. sum the loss terms
    1. loss[i] = advantage_loss[i] + distance_loss[i]
5. compute the gradient
This gives a closed-form calculation of the gradient as:
grad[i] = 2 * (advantage[i] + distance_penalty) * (query_embedding - result_embedding[i])
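The following is a small standalone numpy sketch of the steps above. It mirrors the formulas only and is not RLUpdate's internal implementation; the variable names are illustrative. It also checks the closed-form gradient against a finite-difference estimate.
import numpy as np
query_embedding = np.array([0.1, 0.2, 0.3])
result_embeddings = np.array([[0.2, 0.4, 0.6],
                              [-0.2, 0.4, -0.6]])
scores = np.array([1.5, 0.5])
distance_penalty = 0.1
# 1. whiten scores into advantages
advantages = (scores - scores.mean()) / scores.std()
# 2-4. per-result loss terms (advantage loss + distance loss)
diff = query_embedding - result_embeddings
loss = (advantages[:, None] + distance_penalty) * diff**2
# 5. closed-form gradient with respect to the query embedding
grad = 2 * (advantages[:, None] + distance_penalty) * diff
# finite-difference check of the gradient along the first embedding dimension
def total_loss(q):
    d = q - result_embeddings
    return ((advantages[:, None] + distance_penalty) * d**2).sum()
eps = 1e-6
q_plus = query_embedding.copy()
q_plus[0] += eps
numeric_grad = (total_loss(q_plus) - total_loss(query_embedding)) / eps
assert np.isclose(numeric_grad, grad[:, 0].sum(), atol=1e-4)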
| | Type | Default | Details |
|---|---|---|---|
| lrs | typing.Union[typing.List[float], numpy.ndarray] | | list of learning rates |
| distance_penalty | float | | distance penalty coefficient |
| max_norm | typing.Optional[float] | None | optional max grad norm for clipping |
| norm_type | typing.Union[float, int, str, NoneType] | 2.0 | norm type |
lrs = np.array([1e-2, 1e-1, 1e0, 1e1])
dp = 0.1
update_function = RLUpdate(lrs, dp, max_norm=1.)
update_module = UpdateModule(update_function)
queries = []
for i in range(1,4):
    q = Query.from_minimal(embedding=[i*0.1, i*0.2, i*0.3, 0.1, 0.1])
    q.update_internal(collection_id=i)
    r1 = Item.from_minimal(embedding=[2*i*0.1, 2*i*0.2, 2*i*0.3, -0.1, -0.1], score=i*1.5)
    r2 = Item.from_minimal(embedding=[-2*i*0.1, 2*i*0.2, -2*i*0.3, -0.1, -0.1], score=i*.5)
    q.add_query_results([r1, r2])
    queries.append(q)
batch = Batch(queries=queries)
batch2 = update_module(batch)
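# each input query yields one updated query per learning rate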
assert len(batch2)/len(batch) == len(lrs)