Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 10, 2025
1 parent af633a2 commit 0ee8d40
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 30 deletions.
26 changes: 14 additions & 12 deletions examples/atlas/get_result_web.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
from functools import partial
import json
import os
from functools import partial
from pathlib import Path

import numpy as np
Expand Down Expand Up @@ -120,18 +120,19 @@ def spilt_web(url: str):
print("No match found")


def get_metric(run,metric_col):
def get_metric(run, metric_col):
"""Extract metric value from wandb run.
Parameters
----------
run : wandb.Run
Weights & Biases run object
Returns
-------
float
Metric value or negative infinity if metric not found
"""
if metric_col not in run.summary:
return float('-inf') # Return -inf for missing metrics to handle in comparisons
Expand All @@ -140,7 +141,7 @@ def get_metric(run,metric_col):

def get_best_method(urls, metric_col="test_acc"):
"""Find the best performing method across multiple wandb sweeps.
Parameters
----------
urls : list
Expand All @@ -163,11 +164,11 @@ def get_best_method(urls, metric_col="test_acc"):

# Track run statistics
run_states = {"all_total_runs": 0, "all_finished_runs": 0}

for step_name, url in zip(step_names, urls):
_, _, sweep_id = spilt_web(url)
sweep = wandb.Api(timeout=1000).sweep(f"{entity}/{project}/{sweep_id}")

# Update run statistics
finished_runs = [run for run in sweep.runs if run.state == "finished"]
run_states.update({
Expand All @@ -182,10 +183,10 @@ def get_best_method(urls, metric_col="test_acc"):
best_run = max(sweep.runs, key=partial(get_metric, metric_col=metric_col)) if goal == "maximize" else \
min(sweep.runs, key=partial(get_metric, metric_col=metric_col)) if goal == "minimize" else \
None

if best_run is None:
raise RuntimeError("Optimization goal must be either 'minimize' or 'maximize'")

if metric_col not in best_run.summary:
continue
if all_best_run is None:
Expand Down Expand Up @@ -323,10 +324,10 @@ def get_new_ans(tissue):

def write_ans(tissue, new_df, output_file=None):
"""Process and write results for a specific tissue type to CSV.
Handles merging of new results with existing data, including conflict detection
for metric values.
Parameters
----------
tissue : str
Expand All @@ -335,6 +336,7 @@ def write_ans(tissue, new_df, output_file=None):
New results to be written
output_file : str, optional
Output file path. Defaults to 'sweep_results/{tissue}_ans.csv'
"""
if output_file is None:
output_file = f"sweep_results/{tissue}_ans.csv"
Expand All @@ -345,7 +347,7 @@ def write_ans(tissue, new_df, output_file=None):

# Reset index to ensure Dataset_id is a regular column
new_df = new_df.reset_index(drop=True)

# Process new data by merging rows with same Dataset_id
new_df_processed = pd.DataFrame()
for dataset_id in new_df['Dataset_id'].unique():
Expand Down
35 changes: 18 additions & 17 deletions examples/atlas/sc_similarity_examples/sim_query_atlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,19 @@ def is_match(config_str):

def is_matching_dict(yaml_str, target_dict):
"""Compare YAML configuration with target dictionary.
Parameters
----------
yaml_str : str
YAML configuration string to parse
target_dict : dict
Target dictionary to compare against
Returns
-------
bool
True if dictionaries match, False otherwise
"""
# Parse YAML string
yaml_config = yaml.safe_load(yaml_str)
Expand All @@ -107,18 +108,19 @@ def is_matching_dict(yaml_str, target_dict):

def get_ans(query_dataset, method):
"""Get test accuracy results for a given dataset and method.
Parameters
----------
query_dataset : str
Dataset identifier
method : str
Method name to analyze
Returns
-------
pandas.DataFrame or None
DataFrame containing test accuracy results, None if results don't exist
"""
result_path = f"{file_root}/tuning/{method}/{query_dataset}/results/atlas/best_test_acc.csv"
if not os.path.exists(result_path):
Expand All @@ -137,33 +139,33 @@ def get_ans(query_dataset, method):

def get_ans_from_cache(query_dataset, method):
"""Get cached test accuracy results for atlas datasets.
Parameters
----------
query_dataset : str
Query dataset identifier
method : str
Method name to analyze
Returns
-------
pandas.DataFrame
DataFrame containing test accuracy results from cache
"""
# Get best method from step2 of atlas datasets
# Search accuracy according to best method (all values should exist)
ans = pd.DataFrame(index=[method],
columns=[f"{atlas_dataset}_from_cache" for atlas_dataset in atlas_datasets])

ans = pd.DataFrame(index=[method], columns=[f"{atlas_dataset}_from_cache" for atlas_dataset in atlas_datasets])

sweep_url = re.search(r"step2:([^|]+)",
conf_data[conf_data["dataset_id"] == query_dataset][method].iloc[0]).group(1)
conf_data[conf_data["dataset_id"] == query_dataset][method].iloc[0]).group(1)
_, _, sweep_id = spilt_web(sweep_url)
sweep = wandb.Api().sweep(f"{entity}/{project}/{sweep_id}")

for atlas_dataset in atlas_datasets:
best_yaml = conf_data[conf_data["dataset_id"] == atlas_dataset][f"{method}_best_yaml"].iloc[0]
match_run = None

# Find matching run configuration
for run in sweep.runs:
if isinstance(best_yaml, float) and np.isnan(best_yaml):
Expand All @@ -172,14 +174,13 @@ def get_ans_from_cache(query_dataset, method):
if match_run is not None:
raise ValueError("Multiple matching runs found when only one expected")
match_run = run

if match_run is None:
logger.warning(f"No matching configuration found for {atlas_dataset} with method {method}")
else:
ans.loc[method, f"{atlas_dataset}_from_cache"] = (
match_run.summary["test_acc"] if "test_acc" in match_run.summary else np.nan
)

ans.loc[method, f"{atlas_dataset}_from_cache"] = (match_run.summary["test_acc"]
if "test_acc" in match_run.summary else np.nan)

return ans


Expand Down
2 changes: 1 addition & 1 deletion examples/atlas/test_get_result_web.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_write_ans(tmp_path):
# 测试冲突情况的处理
write_ans(tissue, conflict_df, output_file=output_file)
final_df = pd.read_csv(output_file)

# 验证新值被更新
assert final_df[final_df['Dataset_id'] == 'dataset1']['method1_best_res'].iloc[0] == 0.7

Expand Down

0 comments on commit 0ee8d40

Please sign in to comment.