Jinja keys #132

Draft · wants to merge 3 commits into main
76 changes: 75 additions & 1 deletion docetl/operations/base.py
@@ -10,7 +10,7 @@
from rich.status import Status
import jsonschema
from pydantic import BaseModel

import jinja2

# FIXME: This should probably live in some utils module?
class classproperty(object):
@@ -62,6 +62,7 @@ def __init__(
"num_retries_on_validate_failure", 0
)
self.syntax_check()
self.compiled_configs = {}

# This must be overridden in a subclass
class schema(BaseModel, extra="allow"):
@@ -142,3 +143,76 @@ def gleaning_check(self) -> None:
raise ValueError(
"'validation_prompt' in 'gleaning' configuration cannot be empty"
)

def evaluate_expression(self, config_path, default_expression=None, default_value=None, **context):
"""Evaluates a jinja2 expression specified in the operation
config (or a default expression if not found) against a given
context.

config_path is itself a jinja2 expression, evaluated within
the context of the operation config (self.config) and one
additional variable: config, which is bound to the entire
config file.

Evaluating config_path should yield a jinja2 expression, or
alternatively a list of expressions (if so, a list of values
is returned).

If no expression is found and no default expression is
provided, or the expression evaluates to Undefined,
default_value is returned.

Example:

Assuming

config_path = "input.title_keys"
self.config = {"input": {"title_keys": ["title", "categories.0.title"]}}
context = {"input": {"title": "Hello", "categories": [{"title": "world"}]}}

this function will return ["Hello", "world"].

"""
if config_path not in self.compiled_configs:
env = jinja2.Environment()
expression = env.compile_expression(config_path)(
config=self.runner.config,
**self.config)
if expression is None:
expression = default_expression
if isinstance(expression, (list, tuple)):
self.compiled_configs[config_path] = [env.compile_expression(e, undefined_to_none=False) for e in expression]
elif isinstance(expression, str):
self.compiled_configs[config_path] = env.compile_expression(expression, undefined_to_none=False)
else:
self.compiled_configs[config_path] = None
expr = self.compiled_configs[config_path]
if expr is None:
return default_value
def expr_or_default(expr):
res = expr(**context)
if isinstance(res, jinja2.Undefined):
return default_value
return res
if isinstance(expr, list):
return [expr_or_default(e) for e in expr]
return expr_or_default(expr)

def evaluate_template(self, config_path, default_template=None, **context):
"""Renders a jinja2 template specified in the operation config
(or a default template if not found) against a given context.

config_path is a jinja2 expression, evaluated within
the context of the operation config (self.config) and one
additional variable: config, which is bound to the entire
config file.
"""
if config_path not in self.compiled_configs:
env = jinja2.Environment()
template = env.compile_expression(config_path)(
config=self.runner.config,
**self.config)
if template is None:
template = default_template
self.compiled_configs[config_path] = jinja2.Template(template)
return self.compiled_configs[config_path].render(**context)
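
For reference, a minimal standalone sketch of the two-step lookup-then-evaluate flow that evaluate_expression implements, using only the jinja2 API; the config and item values are the ones from the docstring example, and nothing below is part of the diff itself:

import jinja2

env = jinja2.Environment()
op_config = {"input": {"title_keys": ["title", "categories.0.title"]}}
item = {"title": "Hello", "categories": [{"title": "world"}]}

# Step 1: config_path ("input.title_keys") is itself a jinja2 expression,
# evaluated against the operation config to obtain the expression(s) to run.
expressions = env.compile_expression("input.title_keys")(**op_config)

# Step 2: each resolved expression is compiled once and evaluated per item.
compiled = [env.compile_expression(e, undefined_to_none=False) for e in expressions]
print([c(**item) for c in compiled])  # ['Hello', 'world']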
3 changes: 1 addition & 2 deletions docetl/operations/map.py
@@ -161,8 +161,7 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
self.status.stop()

def _process_map_item(item: Dict) -> Tuple[Optional[Dict], float]:
prompt_template = Template(self.config["prompt"])
prompt = prompt_template.render(input=item)
prompt = self.evaluate_template("prompt", input=item)

def validation_fn(response: Dict[str, Any]):
output = self.runner.api.parse_llm_response(
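
The one-line change above swaps the inline Template(self.config["prompt"]) for the cached evaluate_template helper. A minimal sketch of the path that call takes, assuming only jinja2 and an illustrative prompt string (not the repository's actual config):

import jinja2

env = jinja2.Environment()
# Hypothetical map-operation config; the prompt text is illustrative.
op_config = {"prompt": "Summarize the document titled '{{ input.title }}'."}

# The config_path "prompt" is evaluated against the config to get the template
# source, which is compiled once and then rendered per input item.
source = env.compile_expression("prompt")(**op_config)
template = jinja2.Template(source)
print(template.render(input={"title": "Hello"}))
# Summarize the document titled 'Hello'.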
46 changes: 27 additions & 19 deletions docetl/operations/resolve.py
@@ -47,7 +47,6 @@ def compare_pair(
model: str,
item1: Dict,
item2: Dict,
blocking_keys: List[str] = [],
timeout_seconds: int = 120,
max_retries_per_timeout: int = 2,
) -> Tuple[bool, float]:
@@ -63,14 +62,15 @@
Returns:
Tuple[bool, float]: A tuple containing a boolean indicating whether the items match and the cost of the comparison.
"""
if blocking_keys:
if all(
key in item1
and key in item2
and item1[key].lower() == item2[key].lower()
for key in blocking_keys
):
return True, 0
keys1 = self.get_blocking_keys(item1)
keys2 = self.get_blocking_keys(item2)

if all(key1 is not None
and key2 is not None
and key1.lower() == key2.lower()
for key1, key2 in zip(keys1, keys2)
):
return True, 0

prompt_template = Template(comparison_prompt)
prompt = prompt_template.render(input1=item1, input2=item2)
@@ -89,6 +89,16 @@
)[0]
return output["is_match"], response.total_cost

def get_blocking_keys(self, item):
return self.evaluate_expression(
"blocking_keys", default_expression=list(item.keys()), **item)

def get_hash_key(self, item):
"""Create a hashable key from the blocking keys"""
return tuple(
str(value) if value is not None else ""
for value in self.get_blocking_keys(item))

def syntax_check(self) -> None:
"""
Checks the configuration of the ResolveOperation for required keys and valid structure.
@@ -226,7 +236,6 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
if len(input_data) == 0:
return [], 0

blocking_keys = self.config.get("blocking_keys", [])
blocking_threshold = self.config.get("blocking_threshold")
blocking_conditions = self.config.get("blocking_conditions", [])
if self.status:
@@ -243,9 +252,6 @@
raise ValueError("Operation cancelled by user.")

input_schema = self.config.get("input", {}).get("schema", {})
if not blocking_keys:
# Set them to all keys in the input data
blocking_keys = list(input_data[0].keys())
limit_comparisons = self.config.get("limit_comparisons")
total_cost = 0

@@ -264,7 +270,10 @@ def get_embeddings_batch(
items: List[Dict[str, Any]]
) -> List[Tuple[List[float], float]]:
texts = [
" ".join(str(item[key]) for key in blocking_keys if key in item)
" ".join(
str(value) for value
in self.get_blocking_keys(item)
if value is not None)
for item in items
]
response = self.runner.api.gen_embedding(
@@ -291,14 +300,14 @@ def get_embeddings_batch(
costs.extend([r[1] for r in result])

total_cost += sum(costs)

# Generate all pairs to compare, ensuring no duplicate comparisons
def get_unique_comparison_pairs():
# Create a mapping of values to their indices
value_to_indices = {}
for i, item in enumerate(input_data):
# Create a hashable key from the blocking keys
key = tuple(str(item.get(k, "")) for k in blocking_keys)
key = self.get_hash_key(item)
if key not in value_to_indices:
value_to_indices[key] = []
value_to_indices[key].append(i)
@@ -384,8 +393,8 @@ def merge_clusters(item1, item2):
clusters[root2] = set()

# Also merge all other indices that share the same values
key1 = tuple(str(input_data[item1].get(k, "")) for k in blocking_keys)
key2 = tuple(str(input_data[item2].get(k, "")) for k in blocking_keys)
key1 = self.get_hash_key(input_data[item1])
key2 = self.get_hash_key(input_data[item2])

# Merge all indices with the same values
for idx in value_to_indices.get(key1, []):
@@ -433,7 +442,6 @@ def merge_clusters(item1, item2):
self.config.get("comparison_model", self.default_model),
input_data[pair[0]],
input_data[pair[1]],
blocking_keys,
timeout_seconds=self.config.get("timeout", 120),
max_retries_per_timeout=self.config.get(
"max_retries_per_timeout", 2
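
Taken together, the resolve.py changes route all blocking-key access through get_blocking_keys and get_hash_key, so blocking_keys can now be jinja2 expressions rather than plain column names. A minimal sketch of what that enables, with illustrative keys and item values rather than anything from a real pipeline:

import jinja2

env = jinja2.Environment()
# blocking_keys as jinja2 expressions: nested fields can participate in blocking.
blocking_keys = ["title", "categories.0.title"]
item = {"title": "Hello", "categories": [{"title": "world"}]}

values = [
    env.compile_expression(k, undefined_to_none=False)(**item)
    for k in blocking_keys
]
# Hashable grouping key in the spirit of get_hash_key: None becomes "".
hash_key = tuple(str(v) if v is not None else "" for v in values)
print(hash_key)  # ('Hello', 'world')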