From 0ee8d409bd22e0d38ee1dc1d7a316f6cab506fac Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 Jan 2025 12:54:06 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/atlas/get_result_web.py | 26 +++++++------- .../sc_similarity_examples/sim_query_atlas.py | 35 ++++++++++--------- examples/atlas/test_get_result_web.py | 2 +- 3 files changed, 33 insertions(+), 30 deletions(-) diff --git a/examples/atlas/get_result_web.py b/examples/atlas/get_result_web.py index a479d9de..d4d2bd4d 100644 --- a/examples/atlas/get_result_web.py +++ b/examples/atlas/get_result_web.py @@ -1,7 +1,7 @@ import argparse -from functools import partial import json import os +from functools import partial from pathlib import Path import numpy as np @@ -120,18 +120,19 @@ def spilt_web(url: str): print("No match found") -def get_metric(run,metric_col): +def get_metric(run, metric_col): """Extract metric value from wandb run. - + Parameters ---------- run : wandb.Run Weights & Biases run object - + Returns ------- float Metric value or negative infinity if metric not found + """ if metric_col not in run.summary: return float('-inf') # Return -inf for missing metrics to handle in comparisons @@ -140,7 +141,7 @@ def get_metric(run,metric_col): def get_best_method(urls, metric_col="test_acc"): """Find the best performing method across multiple wandb sweeps. - + Parameters ---------- urls : list @@ -163,11 +164,11 @@ def get_best_method(urls, metric_col="test_acc"): # Track run statistics run_states = {"all_total_runs": 0, "all_finished_runs": 0} - + for step_name, url in zip(step_names, urls): _, _, sweep_id = spilt_web(url) sweep = wandb.Api(timeout=1000).sweep(f"{entity}/{project}/{sweep_id}") - + # Update run statistics finished_runs = [run for run in sweep.runs if run.state == "finished"] run_states.update({ @@ -182,10 +183,10 @@ def get_best_method(urls, metric_col="test_acc"): best_run = max(sweep.runs, key=partial(get_metric, metric_col=metric_col)) if goal == "maximize" else \ min(sweep.runs, key=partial(get_metric, metric_col=metric_col)) if goal == "minimize" else \ None - + if best_run is None: raise RuntimeError("Optimization goal must be either 'minimize' or 'maximize'") - + if metric_col not in best_run.summary: continue if all_best_run is None: @@ -323,10 +324,10 @@ def get_new_ans(tissue): def write_ans(tissue, new_df, output_file=None): """Process and write results for a specific tissue type to CSV. - + Handles merging of new results with existing data, including conflict detection for metric values. - + Parameters ---------- tissue : str @@ -335,6 +336,7 @@ def write_ans(tissue, new_df, output_file=None): New results to be written output_file : str, optional Output file path. Defaults to 'sweep_results/{tissue}_ans.csv' + """ if output_file is None: output_file = f"sweep_results/{tissue}_ans.csv" @@ -345,7 +347,7 @@ def write_ans(tissue, new_df, output_file=None): # Reset index to ensure Dataset_id is a regular column new_df = new_df.reset_index(drop=True) - + # Process new data by merging rows with same Dataset_id new_df_processed = pd.DataFrame() for dataset_id in new_df['Dataset_id'].unique(): diff --git a/examples/atlas/sc_similarity_examples/sim_query_atlas.py b/examples/atlas/sc_similarity_examples/sim_query_atlas.py index 07f1b9da..e7024a93 100644 --- a/examples/atlas/sc_similarity_examples/sim_query_atlas.py +++ b/examples/atlas/sc_similarity_examples/sim_query_atlas.py @@ -76,18 +76,19 @@ def is_match(config_str): def is_matching_dict(yaml_str, target_dict): """Compare YAML configuration with target dictionary. - + Parameters ---------- yaml_str : str YAML configuration string to parse target_dict : dict Target dictionary to compare against - + Returns ------- bool True if dictionaries match, False otherwise + """ # Parse YAML string yaml_config = yaml.safe_load(yaml_str) @@ -107,18 +108,19 @@ def is_matching_dict(yaml_str, target_dict): def get_ans(query_dataset, method): """Get test accuracy results for a given dataset and method. - + Parameters ---------- query_dataset : str Dataset identifier method : str Method name to analyze - + Returns ------- pandas.DataFrame or None DataFrame containing test accuracy results, None if results don't exist + """ result_path = f"{file_root}/tuning/{method}/{query_dataset}/results/atlas/best_test_acc.csv" if not os.path.exists(result_path): @@ -137,33 +139,33 @@ def get_ans(query_dataset, method): def get_ans_from_cache(query_dataset, method): """Get cached test accuracy results for atlas datasets. - + Parameters ---------- query_dataset : str Query dataset identifier method : str Method name to analyze - + Returns ------- pandas.DataFrame DataFrame containing test accuracy results from cache + """ # Get best method from step2 of atlas datasets # Search accuracy according to best method (all values should exist) - ans = pd.DataFrame(index=[method], - columns=[f"{atlas_dataset}_from_cache" for atlas_dataset in atlas_datasets]) - + ans = pd.DataFrame(index=[method], columns=[f"{atlas_dataset}_from_cache" for atlas_dataset in atlas_datasets]) + sweep_url = re.search(r"step2:([^|]+)", - conf_data[conf_data["dataset_id"] == query_dataset][method].iloc[0]).group(1) + conf_data[conf_data["dataset_id"] == query_dataset][method].iloc[0]).group(1) _, _, sweep_id = spilt_web(sweep_url) sweep = wandb.Api().sweep(f"{entity}/{project}/{sweep_id}") - + for atlas_dataset in atlas_datasets: best_yaml = conf_data[conf_data["dataset_id"] == atlas_dataset][f"{method}_best_yaml"].iloc[0] match_run = None - + # Find matching run configuration for run in sweep.runs: if isinstance(best_yaml, float) and np.isnan(best_yaml): @@ -172,14 +174,13 @@ def get_ans_from_cache(query_dataset, method): if match_run is not None: raise ValueError("Multiple matching runs found when only one expected") match_run = run - + if match_run is None: logger.warning(f"No matching configuration found for {atlas_dataset} with method {method}") else: - ans.loc[method, f"{atlas_dataset}_from_cache"] = ( - match_run.summary["test_acc"] if "test_acc" in match_run.summary else np.nan - ) - + ans.loc[method, f"{atlas_dataset}_from_cache"] = (match_run.summary["test_acc"] + if "test_acc" in match_run.summary else np.nan) + return ans diff --git a/examples/atlas/test_get_result_web.py b/examples/atlas/test_get_result_web.py index 3bd817ae..a0e9a8b4 100644 --- a/examples/atlas/test_get_result_web.py +++ b/examples/atlas/test_get_result_web.py @@ -108,7 +108,7 @@ def test_write_ans(tmp_path): # 测试冲突情况的处理 write_ans(tissue, conflict_df, output_file=output_file) final_df = pd.read_csv(output_file) - + # 验证新值被更新 assert final_df[final_df['Dataset_id'] == 'dataset1']['method1_best_res'].iloc[0] == 0.7