Data Source

Data Source functions and classes

The Data Source step runs a set of queries against some data source.

The queries are executed by a DataSourceFunction, which follows a schema mapping inputs (List[Query]) to outputs (List[DataSourceResponse]).

The DataSourceModule manages execution of a DataSourceFunction. The DataSourceModule gathers valid queries, sends them to the DataSourceFunction, and processes the results.


source

DataSourceModule

 DataSourceModule (function:Callable[[List[emb_opt.schemas.Query]],
                   List[emb_opt.schemas.DataSourceResponse]])

Module - module base class

Given an input Batch, the Module:

1. gathers inputs to the function
2. executes the function
3. validates the results of the function with output_schema
4. scatters results back into the Batch

Type Details

function : typing.Callable[[typing.List[emb_opt.schemas.Query]], typing.List[emb_opt.schemas.DataSourceResponse]]
    data function
from typing import Dict, List

from emb_opt.schemas import DataSourceResponse, Item, Query

def build_batch():
    embeddings = [[0.1], [0.2], [0.3]]
    batch = build_batch_from_embeddings(embeddings)
    return batch

def data_source_test(queries: List[Query]) -> List[DataSourceResponse]:
    results = []
    for i, query in enumerate(queries):
        if i == 0:
            # invalid response: this query will be marked as removed
            response = DataSourceResponse(valid=False, data={'test':'test false response'},
                                         query_results=[Item.from_minimal(item='', embedding=[0.1])])
        elif i == 1:
            # valid response with no results: this query will also be removed
            response = DataSourceResponse(valid=True, data={'test':'test empty response'},
                                         query_results=[])
        elif i == 2:
            # valid response with results: this query is kept
            response = DataSourceResponse(valid=True, data={'test':'test normal response'},
                                         query_results=[Item.from_minimal(item='1', embedding=[0.1]), 
                                                       Item.from_minimal(item='2', embedding=[0.2])])
        results.append(response)
    return results

batch = build_batch()
data_module = DataSourceModule(data_source_test)
batch2 = data_module(batch)
assert [i.internal.removed for i in batch2] == [True, True, False]

for q in batch2:
    for r in q:
        assert r.internal.parent_id == q.id

Because the outputs are parsed with pydantic, the function can also return JSON dictionaries matching the DataSourceResponse schema:

from typing import Dict, List

def json_test(queries: List[Query]) -> List[Dict]:
    results = [
        {
            'valid' : True,
            'data' : {},
            'query_results' : [
                {
                    'id' : None,
                    'item' : 'test',
                    'embedding' : [0.1],
                    'score' : None,
                    'data' : None
                }
            ]
        }
        for _ in queries
    ]
    return results

batch = build_batch()
data_module = DataSourceModule(json_test)
batch2 = data_module(batch)

source

DataSourcePlugin

 DataSourcePlugin ()

DataSourcePlugin - documentation for plugin functions implementing DataSourceFunction

A valid DataSourceFunction is any function that maps List[Query] to List[DataSourceResponse]. The inputs will be given as Query objects. The outputs can be either a list of DataSourceResponse objects or a list of valid JSON dictionaries that match the DataSourceResponse schema.

Query schema:

{
    'item' : Optional[Any],
    'embedding' : List[float],
    'data' : Optional[Dict],
    'query_results' : [] # will be empty at this stage
}

Item schema:

{
    'id' : Optional[Union[str, int]],
    'item' : Optional[Any],
    'embedding' : List[float],
    'score' : None, # will be None at this stage
    'data' : Optional[Dict],
}

Input schema:

List[Query]

DataSourceResponse schema:

{
    'valid' : bool,
    'data' : Optional[Dict],
    'query_results' : List[Item]
}

Output schema:

List[DataSourceResponse]
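As a sketch of this contract (not code from the library), the function below returns plain dictionaries matching the DataSourceResponse schema. In real use the inputs arrive as Query objects with attribute access; dicts stand in for them here so the example runs standalone, and the item values are fabricated for illustration.

```python
from typing import Dict, List

def dict_data_source(queries: List[Dict]) -> List[Dict]:
    """Hypothetical DataSourceFunction returning JSON-style dicts."""
    responses = []
    for query in queries:
        # fabricate one result item per query, following the Item schema
        item = {
            'id': None,
            'item': f"item-for-{query['embedding'][0]:.1f}",
            'embedding': query['embedding'],
            'score': None,   # scores are assigned later in the pipeline
            'data': None,
        }
        # one DataSourceResponse-shaped dict per input query
        responses.append({
            'valid': True,
            'data': {},
            'query_results': [item],
        })
    return responses

queries = [{'embedding': [0.1]}, {'embedding': [0.2]}]
responses = dict_data_source(queries)
```

Since the dicts match the DataSourceResponse schema, pydantic parsing inside DataSourceModule would validate them the same way as constructed objects.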

The NumpyDataPlugin data source works with any numpy array of embeddings.


source

NumpyDataPlugin

 NumpyDataPlugin (k:int, item_embeddings:numpy.ndarray,
                  item_data:Optional[List[Dict]]=None,
                  id_key:Optional[str]=None, item_key:Optional[str]=None,
                  distance_metric:str='euclidean',
                  distance_cutoff:Optional[float]=None)

NumpyDataPlugin - data plugin for working with numpy arrays. The data query runs k nearest neighbors of the query embeddings against item_embeddings using distance_metric.

Optionally, item_data can be provided as a list of dicts, where item_data[i] corresponds to the data for item_embeddings[i].

If item_data is provided, item_data[i][id_key] defines the ID of item i, and item_data[i][item_key] defines the value of item i.

distance_metric is any valid scipy distance metric; see https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cdist.html

If distance_cutoff is specified, query results with a distance greater than distance_cutoff are ignored.
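The kNN query described above can be sketched as follows. This is an assumed rendering of the internals, not the actual implementation: pairwise distances via scipy's cdist, top-k selection per query, then an optional distance cutoff.

```python
import numpy as np
from scipy.spatial.distance import cdist

rng = np.random.default_rng(0)
item_embeddings = rng.standard_normal((100, 8))   # 100 items, 8-dim embeddings
query_embeddings = rng.standard_normal((3, 8))    # 3 queries
k = 5
distance_cutoff = 4.0

# distances[i, j] = distance from query i to item j
distances = cdist(query_embeddings, item_embeddings, metric='euclidean')

# indices of the k nearest items per query, sorted by increasing distance
topk_idx = np.argsort(distances, axis=1)[:, :k]
topk_dist = np.take_along_axis(distances, topk_idx, axis=1)

# boolean mask of results within the cutoff; results outside it are ignored
within_cutoff = topk_dist <= distance_cutoff
```

In the plugin, the surviving indices would then be mapped back to item_data entries (via id_key and item_key) to build the returned Item objects.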

Type Default Details

k : int
    k nearest neighbors to return
item_embeddings : ndarray
    item embeddings
item_data : typing.Optional[typing.List[typing.Dict]], default None
    optional list of item data dicts, where item_data[i] corresponds to item_embeddings[i]
id_key : typing.Optional[str], default None
    optional key for item id (should be in each item_data dict)
item_key : typing.Optional[str], default None
    optional key for item value (should be in each item_data dict)
distance_metric : str, default 'euclidean'
    distance metric; see options at https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cdist.html
distance_cutoff : typing.Optional[float], default None
    query-to-result distance cutoff
import numpy as np

n_vectors = 256
d_vectors = 64
k = 10
n_queries = 5

vectors = np.random.randn(n_vectors, d_vectors)
vector_data = [{'index':str(np.random.randint(0,1e6)), 
                'other':np.random.randint(0,1e3), 
                'item':str(np.random.randint(0,1e4))} 
               for i in range(vectors.shape[0])]

data_function = NumpyDataPlugin(k, vectors, vector_data, id_key='index', item_key='item', 
                                distance_metric='cosine', distance_cutoff=0.7)
data_module = DataSourceModule(data_function)

batch = build_batch_from_embeddings(np.random.randn(n_queries, d_vectors))
batch2 = data_module(batch)

for q in batch2:
    for r in q:
        assert r.internal.parent_id == q.id
        
assert all([max(i.data['query_distance']) < 0.7 for i in batch2 if i.data['query_distance']])

source

DataPluginGradWrapper

 DataPluginGradWrapper (function:Callable[[List[emb_opt.schemas.Query]],
                        List[emb_opt.schemas.DataSourceResponse]],
                        lrs:numpy.ndarray)

DataPluginGradWrapper - wraps a DataSourceFunction to allow for gradient-based queries. The score gradient is used to generate hypothetical query embeddings following new_query = old_query + lr*grad.

This should be used in conjunction with UpdatePluginGradientWrapper or an UpdateFunction that assigns the gradient to query.data['_score_grad'].

Note that the gradient in this case is expected to point in the direction of increasing score. If using a custom gradient computation method, you may need to flip the sign of the gradient.

Type Details

function : typing.Callable[[typing.List[emb_opt.schemas.Query]], typing.List[emb_opt.schemas.DataSourceResponse]]
    data function to wrap
lrs : ndarray
    array of learning rates
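The update rule new_query = old_query + lr*grad can be sketched as below. The values are hypothetical; in the pipeline the gradient would come from query.data['_score_grad'], and the wrapped function would be queried once per generated embedding.

```python
import numpy as np

old_query = np.array([0.5, -0.2, 0.1])     # current query embedding
score_grad = np.array([1.0, 0.0, -1.0])    # stand-in for data['_score_grad']
lrs = np.array([0.01, 0.1, 0.5])           # array of learning rates

# one hypothetical query embedding per learning rate:
# new_queries[i] = old_query + lrs[i] * score_grad
new_queries = old_query[None, :] + lrs[:, None] * score_grad[None, :]
```

Because the gradient points toward increasing score, larger learning rates probe further along the (locally) score-improving direction.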