diff --git a/docetl/operations/base.py b/docetl/operations/base.py
index c9272d75..6f400bce 100644
--- a/docetl/operations/base.py
+++ b/docetl/operations/base.py
@@ -10,7 +10,7 @@
 from rich.status import Status
 import jsonschema
 from pydantic import BaseModel
-
+import jinja2
 
 # FIXME: This should probably live in some utils module?
 class classproperty(object):
@@ -62,6 +62,7 @@ def __init__(
             "num_retries_on_validate_failure", 0
         )
         self.syntax_check()
+        self.compiled_configs = {}
 
         # This must be overridden in a subclass
         class schema(BaseModel, extra="allow"):
@@ -142,3 +143,76 @@ def gleaning_check(self) -> None:
             raise ValueError(
                 "'validation_prompt' in 'gleaning' configuration cannot be empty"
             )
+
+    def evaluate_expression(self, config_path, default_expression=None, default_value=None, **context):
+        """Evaluate a jinja2 expression specified in the operation
+        config (or a default expression if none is found) against a
+        given context.
+
+        config_path is itself a jinja2 expression, evaluated against
+        the operation config (self.config) plus one additional
+        variable: config, which is bound to the entire config file.
+
+        Evaluating config_path should yield a jinja2 expression, or
+        alternatively a list of expressions (in which case a list of
+        values is returned).
+
+        If no expression is found and no default expression is
+        provided, or if the expression evaluates to Undefined,
+        default_value is returned.
+
+        Example:
+
+            Assuming
+
+                config_path = "input.title_keys"
+                self.config = {"input": {"title_keys": ["title", "categories.0.title"]}}
+                context = {"title": "Hello", "categories": [{"title": "world"}]}
+
+            this function will return ["Hello", "world"].
+        """
+        if config_path not in self.compiled_configs:
+            env = jinja2.Environment()
+            expression = env.compile_expression(config_path)(
+                config=self.runner.config,
+                **self.config)
+            if expression is None:
+                expression = default_expression
+            if isinstance(expression, (list, tuple)):
+                self.compiled_configs[config_path] = [
+                    env.compile_expression(e, undefined_to_none=False)
+                    for e in expression]
+            elif isinstance(expression, str):
+                self.compiled_configs[config_path] = env.compile_expression(
+                    expression, undefined_to_none=False)
+            else:
+                self.compiled_configs[config_path] = None
+        expr = self.compiled_configs[config_path]
+        if expr is None:
+            return default_value
+
+        def expr_or_default(e):
+            res = e(**context)
+            # With undefined_to_none=False a missing value comes back as an
+            # Undefined *instance*, so an identity check against the
+            # jinja2.Undefined class would never match; use isinstance.
+            if isinstance(res, jinja2.Undefined):
+                return default_value
+            return res
+
+        if isinstance(expr, list):
+            return [expr_or_default(e) for e in expr]
+        return expr_or_default(expr)
+
+    def evaluate_template(self, config_path, default_template=None, **context):
+        """Render a jinja2 template specified in the operation config
+        (or a default template if none is found) against a given
+        context.
+
+        config_path is itself a jinja2 expression, evaluated against
+        the operation config (self.config) plus one additional
+        variable: config, which is bound to the entire config file.
+        Evaluating config_path should yield a template string.
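+
+        Example (mirroring the map operation's use of "prompt" below;
+        the template and item values are illustrative):
+
+            Assuming
+
+                config_path = "prompt"
+                self.config = {"prompt": "Summarize {{ input.title }}"}
+                context = {"input": {"title": "Hello"}}
+
+            this function will return "Summarize Hello".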
+ """ + if config_path not in self.compiled_configs: + env = jinja2.Environment() + template = env.compile_expression(config_path)( + config=self.runner.config, + **self.config) + if template is None: + template = default_template + self.compiled_configs[config_path] = jinja2.Template(template) + return self.compiled_configs[config_path].render(**context) diff --git a/docetl/operations/map.py b/docetl/operations/map.py index 7073ff3e..5ef30fd1 100644 --- a/docetl/operations/map.py +++ b/docetl/operations/map.py @@ -161,8 +161,7 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: self.status.stop() def _process_map_item(item: Dict) -> Tuple[Optional[Dict], float]: - prompt_template = Template(self.config["prompt"]) - prompt = prompt_template.render(input=item) + prompt = self.evaluate_template("prompt", input=item) def validation_fn(response: Dict[str, Any]): output = self.runner.api.parse_llm_response( diff --git a/docetl/operations/resolve.py b/docetl/operations/resolve.py index fff71343..ba7bf45c 100644 --- a/docetl/operations/resolve.py +++ b/docetl/operations/resolve.py @@ -47,7 +47,6 @@ def compare_pair( model: str, item1: Dict, item2: Dict, - blocking_keys: List[str] = [], timeout_seconds: int = 120, max_retries_per_timeout: int = 2, ) -> Tuple[bool, float]: @@ -63,14 +62,15 @@ def compare_pair( Returns: Tuple[bool, float]: A tuple containing a boolean indicating whether the items match and the cost of the comparison. """ - if blocking_keys: - if all( - key in item1 - and key in item2 - and item1[key].lower() == item2[key].lower() - for key in blocking_keys - ): - return True, 0 + keys1 = self.get_blocking_keys(item1) + keys2 = self.get_blocking_keys(item1) + + if all(key1 is not None + and key2 is not None + and key1.lower() == key2.lower() + for key1, key2 in zip(keys1, keys2) + ): + return True, 0 prompt_template = Template(comparison_prompt) prompt = prompt_template.render(input1=item1, input2=item2) @@ -89,6 +89,16 @@ def compare_pair( )[0] return output["is_match"], response.total_cost + def get_blocking_keys(self, item): + return self.evaluate_expression( + "blocking_keys", default_expression=list(item.keys()), **item) + + def get_hash_key(self, item): + """Create a hashable key from the blocking keys""" + return tuple( + str(value) if value is not None else "" + for value in self.get_blocking_keys(item)) + def syntax_check(self) -> None: """ Checks the configuration of the ResolveOperation for required keys and valid structure. 
+        """
+        return tuple(
+            str(value) if value is not None else ""
+            for value in self.get_blocking_keys(item))
+
     def syntax_check(self) -> None:
         """
         Checks the configuration of the ResolveOperation for required keys and valid structure.
@@ -226,7 +236,6 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
         if len(input_data) == 0:
             return [], 0
 
-        blocking_keys = self.config.get("blocking_keys", [])
         blocking_threshold = self.config.get("blocking_threshold")
         blocking_conditions = self.config.get("blocking_conditions", [])
         if self.status:
@@ -243,9 +252,6 @@
             raise ValueError("Operation cancelled by user.")
 
         input_schema = self.config.get("input", {}).get("schema", {})
-        if not blocking_keys:
-            # Set them to all keys in the input data
-            blocking_keys = list(input_data[0].keys())
         limit_comparisons = self.config.get("limit_comparisons")
         total_cost = 0
 
@@ -264,7 +270,10 @@ def get_embeddings_batch(
             items: List[Dict[str, Any]]
         ) -> List[Tuple[List[float], float]]:
             texts = [
-                " ".join(str(item[key]) for key in blocking_keys if key in item)
+                " ".join(
+                    str(value) for value
+                    in self.get_blocking_keys(item)
+                    if value is not None)
                 for item in items
             ]
             response = self.runner.api.gen_embedding(
@@ -291,14 +300,14 @@
                 costs.extend([r[1] for r in result])
 
         total_cost += sum(costs)
-
+
         # Generate all pairs to compare, ensuring no duplicate comparisons
         def get_unique_comparison_pairs():
            # Create a mapping of values to their indices
            value_to_indices = {}
            for i, item in enumerate(input_data):
                # Create a hashable key from the blocking keys
-                key = tuple(str(item.get(k, "")) for k in blocking_keys)
+                key = self.get_hash_key(item)
                if key not in value_to_indices:
                    value_to_indices[key] = []
                value_to_indices[key].append(i)
@@ -384,8 +393,8 @@ def merge_clusters(item1, item2):
                 clusters[root2] = set()
 
             # Also merge all other indices that share the same values
-            key1 = tuple(str(input_data[item1].get(k, "")) for k in blocking_keys)
-            key2 = tuple(str(input_data[item2].get(k, "")) for k in blocking_keys)
+            key1 = self.get_hash_key(input_data[item1])
+            key2 = self.get_hash_key(input_data[item2])
 
             # Merge all indices with the same values
             for idx in value_to_indices.get(key1, []):
@@ -433,7 +442,6 @@
                     self.config.get("comparison_model", self.default_model),
                     input_data[pair[0]],
                     input_data[pair[1]],
-                    blocking_keys,
                     timeout_seconds=self.config.get("timeout", 120),
                     max_retries_per_timeout=self.config.get(
                         "max_retries_per_timeout", 2