```python
def build_batch():
    embeddings = [[0.1], [0.2], [0.3]]
    batch = build_batch_from_embeddings(embeddings)
    return batch
```
```python
def data_source_test(queries: List[Query]) -> List[DataSourceResponse]:
    results = []
    for i, query in enumerate(queries):
        if i == 0:
            response = DataSourceResponse(valid=False, data={'test': 'test false response'},
                                          query_results=[Item.from_minimal(item='', embedding=[0.1])])
        elif i == 1:
            response = DataSourceResponse(valid=True, data={'test': 'test empty response'},
                                          query_results=[])
        elif i == 2:
            response = DataSourceResponse(valid=True, data={'test': 'test normal response'},
                                          query_results=[Item.from_minimal(item='1', embedding=[0.1]),
                                                         Item.from_minimal(item='2', embedding=[0.2])])
        results.append(response)
    return results
```
```python
batch = build_batch()
data_module = DataSourceModule(data_source_test)
batch2 = data_module(batch)

assert [i.internal.removed for i in batch2] == [True, True, False]

for q in batch2:
    for r in q:
        assert r.internal.parent_id == q.id
```

Data Source
The Data Source step runs a set of queries against some data source.
The query function is defined by the DataSourceFunction schema, which maps an input List[Query] to an output List[DataSourceResponse].
The DataSourceModule manages execution of a DataSourceFunction: it gathers valid queries, sends them to the DataSourceFunction, and processes the results.
DataSourceModule
DataSourceModule (function:Callable[[List[emb_opt.schemas.Query]], List[emb_opt.schemas.DataSourceResponse]])
Module - module base class
Given an input Batch, the Module:
1. gathers inputs to the function
2. executes the function
3. validates the results of the function with output_schema
4. scatters results back into the Batch
|  | Type | Details |
|---|---|---|
| function | typing.Callable[[typing.List[emb_opt.schemas.Query]], typing.List[emb_opt.schemas.DataSourceResponse]] | data function |
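The four steps above can be sketched in plain Python. Note that `Query` and `DataSourceResponse` below are minimal stand-ins defined purely for illustration, not the actual `emb_opt.schemas` classes:

```python
# Illustrative sketch of the gather/execute/validate/scatter pattern.
# These dataclasses are simplified stand-ins, not the real emb_opt schemas.
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional

@dataclass
class Query:
    embedding: List[float]
    valid: bool = True                              # invalid queries are skipped
    query_results: List[Any] = field(default_factory=list)

@dataclass
class DataSourceResponse:
    valid: bool
    data: Optional[Dict]
    query_results: List[Any]

def run_module(batch: List[Query],
               function: Callable[[List[Query]], List[DataSourceResponse]]
               ) -> List[Query]:
    # 1. gather valid queries from the batch
    gathered = [q for q in batch if q.valid]
    # 2. execute the function
    responses = function(gathered)
    # 3. validate the results (one response per gathered query)
    assert len(responses) == len(gathered)
    # 4. scatter results back into the batch
    for query, response in zip(gathered, responses):
        query.valid = response.valid
        query.query_results = response.query_results
    return batch
```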
Thanks to pydantic data parsing, the function can also return a JSON response:
```python
def json_test(queries: List[Query]) -> List[Dict]:
    results = [
        {
            'valid': True,
            'data': {},
            'query_results': [
                {
                    'id': None,
                    'item': 'test',
                    'embedding': [0.1],
                    'score': None,
                    'data': None
                }
            ]
        }
        for i in queries
    ]
    return results
```
```python
batch = build_batch()
data_module = DataSourceModule(json_test)
batch2 = data_module(batch)
```

DataSourcePlugin
DataSourcePlugin ()
DataSourcePlugin - documentation for plugin functions to DataSourceFunction
A valid DataSourceFunction is any function that maps List[Query] to List[DataSourceResponse]. The inputs will be given as Query objects. The outputs can be either a list of DataSourceResponse objects or a list of valid JSON dictionaries that match the DataSourceResponse schema.
Query schema:
```
{
    'item' : Optional[Any],
    'embedding' : List[float],
    'data' : Optional[Dict],
    'query_results' : [] # will be empty at this stage
}
```
Item schema:
```
{
    'id' : Optional[Union[str, int]],
    'item' : Optional[Any],
    'embedding' : List[float],
    'score' : None, # will be None at this stage
    'data' : Optional[Dict],
}
```
Input schema:
List[Query]
DataSourceResponse schema:
```
{
    'valid' : bool,
    'data' : Optional[Dict],
    'query_results' : List[Item]
}
```
Output schema:
List[DataSourceResponse]
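As a sketch of how these schemas compose, the listing below mirrors them as hypothetical pydantic models. The real classes live in `emb_opt.schemas`; only the field names and types follow the listings above:

```python
# Illustrative pydantic mirror of the schemas above -- not the real
# emb_opt.schemas classes, just models with the same field layout.
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel

class Item(BaseModel):
    id: Optional[Union[str, int]] = None
    item: Optional[Any] = None
    embedding: List[float]
    score: Optional[float] = None     # will be None at this stage
    data: Optional[Dict] = None

class DataSourceResponse(BaseModel):
    valid: bool
    data: Optional[Dict] = None
    query_results: List[Item]

# a raw JSON dict parses directly into the schema
response = DataSourceResponse(**{
    'valid': True,
    'data': {},
    'query_results': [{'id': None, 'item': 'test', 'embedding': [0.1],
                       'score': None, 'data': None}],
})
```

This is why a DataSourceFunction may return plain dictionaries: pydantic validates them against the schema at construction time.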
The NumpyDataPlugin data source works with any numpy array of embeddings.
NumpyDataPlugin
NumpyDataPlugin (k:int, item_embeddings:numpy.ndarray, item_data:Optional[List[Dict]]=None, id_key:Optional[str]=None, item_key:Optional[str]=None, distance_metric:str='euclidean', distance_cutoff:Optional[float]=None)
NumpyDataPlugin - data plugin for working with numpy arrays. The data query will run k nearest neighbors of the query embeddings against the item_embeddings using distance_metric
Optionally, item_data can be provided as a list of dicts, where item_data[i] corresponds to the data for item_embeddings[i].
If item_data is provided, the value of item_data[i][id_key] defines the ID of item i, and item_data[i][item_key] defines the value of item i
distance_metric is any valid scipy distance metric. see https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cdist.html
If distance_cutoff is specified, query results with a distance greater than distance_cutoff are ignored
|  | Type | Default | Details |
|---|---|---|---|
| k | int |  | k nearest neighbors to return |
| item_embeddings | ndarray |  | item embeddings |
| item_data | typing.Optional[typing.List[typing.Dict]] | None | Optional list of item data dicts |
| id_key | typing.Optional[str] | None | Optional key for item id (should be in item_data dict) |
| item_key | typing.Optional[str] | None | Optional key for item value (should be in item_data dict) |
| distance_metric | str | euclidean | distance metric, see options at https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cdist.html |
| distance_cutoff | typing.Optional[float] | None | query to result distance cutoff |
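The kNN lookup described above can be sketched directly with scipy's cdist: compute pairwise query-to-item distances, take the k nearest items per query, and drop results beyond the cutoff. `knn_query` is a hypothetical helper written for illustration, not part of the plugin's API:

```python
# Sketch of NumpyDataPlugin's core lookup: pairwise distances via cdist,
# top-k selection, and an optional distance cutoff.
import numpy as np
from scipy.spatial.distance import cdist

def knn_query(query_embeddings: np.ndarray,
              item_embeddings: np.ndarray,
              k: int,
              distance_metric: str = 'euclidean',
              distance_cutoff: float = None):
    # distances[i, j] = distance from query i to item j
    distances = cdist(query_embeddings, item_embeddings, metric=distance_metric)
    # indices of the k nearest items for each query
    topk = np.argsort(distances, axis=1)[:, :k]
    results = []
    for i, idxs in enumerate(topk):
        if distance_cutoff is not None:
            # drop results farther from the query than the cutoff
            idxs = idxs[distances[i, idxs] < distance_cutoff]
        results.append(idxs)
    return results
```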
```python
n_vectors = 256
d_vectors = 64
k = 10
n_queries = 5

vectors = np.random.randn(n_vectors, d_vectors)
vector_data = [{'index': str(np.random.randint(0, 1e6)),
                'other': np.random.randint(0, 1e3),
                'item': str(np.random.randint(0, 1e4))}
               for i in range(vectors.shape[0])]

data_function = NumpyDataPlugin(k, vectors, vector_data, id_key='index', item_key='item',
                                distance_metric='cosine', distance_cutoff=0.7)
data_module = DataSourceModule(data_function)

batch = build_batch_from_embeddings(np.random.randn(n_queries, d_vectors))
batch2 = data_module(batch)

for q in batch2:
    for r in q:
        assert r.internal.parent_id == q.id

assert all([max(i.data['query_distance']) < 0.7 for i in batch2 if i.data['query_distance']])
```

DataPluginGradWrapper
DataPluginGradWrapper (function:Callable[[List[emb_opt.schemas.Query]], List[emb_opt.schemas.DataSourceResponse]], lrs:numpy.ndarray)
DataPluginGradWrapper - wraps a DataSourceFunction to allow for gradient-based queries. The score gradient is used to generate hypothetical query embeddings following new_query = old_query + lr*grad
This should be used in conjunction with UpdatePluginGradientWrapper or an UpdateFunction that assigns the gradient to query.data['_score_grad']
Note that the gradient in this case is expected to point in the direction of increasing score. If using a custom gradient computation method, you may need to flip the sign of the gradient.
|  | Type | Details |
|---|---|---|
| function | typing.Callable[[typing.List[emb_opt.schemas.Query]], typing.List[emb_opt.schemas.DataSourceResponse]] | data function to wrap |
| lrs | ndarray | array of learning rates |
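The hypothetical query generation new_query = old_query + lr*grad can be sketched in numpy. `expand_queries` below is an illustrative helper, not the wrapper's actual API; it broadcasts one update per learning rate:

```python
# Sketch of gradient-based query expansion: for each learning rate lr,
# new_query = old_query + lr * grad, where grad points toward higher score.
import numpy as np

def expand_queries(query_embeddings: np.ndarray,
                   score_grads: np.ndarray,
                   lrs: np.ndarray) -> np.ndarray:
    # result[i, j] = query i stepped with learning rate j
    # shapes: (n, 1, d) + (1, m, 1) * (n, 1, d) -> (n, m, d)
    return query_embeddings[:, None, :] + lrs[None, :, None] * score_grads[:, None, :]

queries = np.array([[0.0, 0.0]])   # one query in 2d
grads = np.array([[1.0, 0.0]])     # score gradient for that query
lrs = np.array([0.1, 0.5])         # two learning rates -> two new queries
expanded = expand_queries(queries, grads, lrs)
```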