Skip to content

Commit

Permalink
git merge
Browse files Browse the repository at this point in the history
  • Loading branch information
xingzhongyu committed Jan 11, 2025
2 parents 3336d4c + ffc84b2 commit fadbea6
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 31 deletions.
22 changes: 12 additions & 10 deletions examples/atlas/get_result_web.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
from functools import partial
import json
import os
from functools import partial
from pathlib import Path

import numpy as np
Expand Down Expand Up @@ -120,18 +120,19 @@ def spilt_web(url: str):
print("No match found")


def get_metric(run,metric_col):
def get_metric(run, metric_col):
"""Extract metric value from wandb run.
Parameters
----------
run : wandb.Run
Weights & Biases run object
Returns
-------
float
Metric value or negative infinity if metric not found
"""
if metric_col not in run.summary:
return float('-inf') # Return -inf for missing metrics to handle in comparisons
Expand All @@ -140,7 +141,7 @@ def get_metric(run,metric_col):

def get_best_method(urls, metric_col="test_acc"):
"""Find the best performing method across multiple wandb sweeps.
Parameters
----------
urls : list
Expand All @@ -163,11 +164,11 @@ def get_best_method(urls, metric_col="test_acc"):

# Track run statistics
run_states = {"all_total_runs": 0, "all_finished_runs": 0}

for step_name, url in zip(step_names, urls):
_, _, sweep_id = spilt_web(url)
sweep = wandb.Api(timeout=1000).sweep(f"{entity}/{project}/{sweep_id}")

# Update run statistics
finished_runs = [run for run in sweep.runs if run.state == "finished"]
run_states.update({
Expand All @@ -182,10 +183,10 @@ def get_best_method(urls, metric_col="test_acc"):
best_run = max(sweep.runs, key=partial(get_metric, metric_col=metric_col)) if goal == "maximize" else \
min(sweep.runs, key=partial(get_metric, metric_col=metric_col)) if goal == "minimize" else \
None

if best_run is None:
raise RuntimeError("Optimization goal must be either 'minimize' or 'maximize'")

if metric_col not in best_run.summary:
continue
if all_best_run is None:
Expand Down Expand Up @@ -335,6 +336,7 @@ def write_ans(tissue, new_df, output_file=None):
New results to be written
output_file : str, optional
Output file path. Defaults to 'sweep_results/{tissue}_ans.csv'
"""
if output_file is None:
output_file = f"sweep_results/{tissue}_ans.csv"
Expand All @@ -345,7 +347,7 @@ def write_ans(tissue, new_df, output_file=None):

# Reset index to ensure Dataset_id is a regular column
new_df = new_df.reset_index(drop=True)

# Process new data by merging rows with same Dataset_id
new_df_processed = pd.DataFrame()
for dataset_id in new_df['Dataset_id'].unique():
Expand Down
14 changes: 12 additions & 2 deletions examples/atlas/sc_similarity_examples/example_usage_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
file_root = Path(__file__).resolve().parent
set_seed(42)
tissue = args.tissue
conf_data = pd.read_csv(f"results/{tissue}_result.csv", index_col=0)
# conf_data = pd.read_csv(f"results/{tissue}_result.csv", index_col=0)
conf_data = pd.read_excel("Cell Type Annotation Atlas.xlsx", sheet_name=tissue)
target_files = list(conf_data[conf_data["queryed"] == False]["dataset_id"])
source_files = list(conf_data[conf_data["queryed"] == True]["dataset_id"])

Expand Down Expand Up @@ -149,4 +150,13 @@ def run_test_case(source_file):
]
ans = run_test_case(source_file)
merged_df = pd.concat(query_ans + [ans], join='inner')
merged_df.to_excel(writer, sheet_name=source_file[:4], index=True)
try:
# 尝试读取指定的分表
existing_df = pd.read_excel(file_root / f"{tissue}_similarity.xlsx", sheet_name=source_file[:4])
# 找出在新数据框中存在但在现有表格中不存在的行
merged_df = pd.concat([existing_df, merged_df])
merged_df = merged_df.drop_duplicates(keep='first')
# 使用 ExcelWriter 更新特定分表
merged_df.to_excel(writer, sheet_name=source_file[:4], index=False)
except ValueError:
merged_df.to_excel(writer, sheet_name=source_file[:4], index=True)
39 changes: 20 additions & 19 deletions examples/atlas/sc_similarity_examples/sim_query_atlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,19 @@ def is_match(config_str):

def is_matching_dict(yaml_str, target_dict):
"""Compare YAML configuration with target dictionary.
Parameters
----------
yaml_str : str
YAML configuration string to parse
target_dict : dict
Target dictionary to compare against
Returns
-------
bool
True if dictionaries match, False otherwise
"""
# Parse YAML string
yaml_config = yaml.safe_load(yaml_str)
Expand All @@ -107,18 +108,19 @@ def is_matching_dict(yaml_str, target_dict):

def get_ans(query_dataset, method):
"""Get test accuracy results for a given dataset and method.
Parameters
----------
query_dataset : str
Dataset identifier
method : str
Method name to analyze
Returns
-------
pandas.DataFrame or None
DataFrame containing test accuracy results, None if results don't exist
"""
result_path = f"{file_root}/tuning/{method}/{query_dataset}/results/atlas/best_test_acc.csv"
if not os.path.exists(result_path):
Expand All @@ -137,33 +139,33 @@ def get_ans(query_dataset, method):

def get_ans_from_cache(query_dataset, method):
"""Get cached test accuracy results for atlas datasets.
Parameters
----------
query_dataset : str
Query dataset identifier
method : str
Method name to analyze
Returns
-------
pandas.DataFrame
DataFrame containing test accuracy results from cache
"""
# Get best method from step2 of atlas datasets
# Search accuracy according to best method (all values should exist)
ans = pd.DataFrame(index=[method],
columns=[f"{atlas_dataset}_from_cache" for atlas_dataset in atlas_datasets])

ans = pd.DataFrame(index=[method], columns=[f"{atlas_dataset}_from_cache" for atlas_dataset in atlas_datasets])

sweep_url = re.search(r"step2:([^|]+)",
conf_data[conf_data["dataset_id"] == query_dataset][method].iloc[0]).group(1)
conf_data[conf_data["dataset_id"] == query_dataset][method].iloc[0]).group(1)
_, _, sweep_id = spilt_web(sweep_url)
sweep = wandb.Api().sweep(f"{entity}/{project}/{sweep_id}")

for atlas_dataset in atlas_datasets:
best_yaml = conf_data[conf_data["dataset_id"] == atlas_dataset][f"{method}_best_yaml"].iloc[0]
match_run = None

# Find matching run configuration
for run in sweep.runs:
if isinstance(best_yaml, float) and np.isnan(best_yaml):
Expand All @@ -172,14 +174,13 @@ def get_ans_from_cache(query_dataset, method):
if match_run is not None:
raise ValueError("Multiple matching runs found when only one expected")
match_run = run

if match_run is None:
logger.warning(f"No matching configuration found for {atlas_dataset} with method {method}")
else:
ans.loc[method, f"{atlas_dataset}_from_cache"] = (
match_run.summary["test_acc"] if "test_acc" in match_run.summary else np.nan
)

ans.loc[method, f"{atlas_dataset}_from_cache"] = (match_run.summary["test_acc"]
if "test_acc" in match_run.summary else np.nan)

return ans


Expand Down Expand Up @@ -207,8 +208,8 @@ def get_ans_from_cache(query_dataset, method):
# "738942eb-ac72-44ff-a64b-8943b5ecd8d9", "a5d95a42-0137-496f-8a60-101e17f263c8",
# "71be997d-ff75-41b9-8a9f-1288c865f921"
# ]
# conf_data = pd.read_excel("Cell Type Annotation Atlas.xlsx", sheet_name=tissue)
conf_data = pd.read_csv(f"results/{tissue}_result.csv", index_col=0)
conf_data = pd.read_excel("Cell Type Annotation Atlas.xlsx", sheet_name=tissue)
# conf_data = pd.read_csv(f"results/{tissue}_result.csv", index_col=0)
atlas_datasets = list(conf_data[conf_data["queryed"] == False]["dataset_id"])
query_datasets = list(conf_data[conf_data["queryed"] == True]["dataset_id"])
if __name__ == "__main__":
Expand Down

0 comments on commit fadbea6

Please sign in to comment.