Commit f23b1ca
update ans
xingzhongyu committed Jan 14, 2025
1 parent 9de47a1 commit f23b1ca
Showing 1 changed file with 56 additions and 27 deletions.
examples/atlas/sc_similarity_examples/example_usage_anndata.py
@@ -12,6 +12,7 @@
 from torch.utils.data import TensorDataset

 from dance.atlas.sc_similarity.anndata_similarity import AnnDataSimilarity, get_anndata
+from dance.settings import DANCEDIR, METADIR
 from dance.utils import set_seed

 # target_files = [
@@ -25,8 +26,8 @@
 #     "eeacb0c1-2217-4cf6-b8ce-1f0fedf1b569"
 # ]
 parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-parser.add_argument("--tissue", type=str, default="blood")
-parser.add_argument("--data_dir", default="../../tuning/temp_data")
+parser.add_argument("--tissue", type=str, default="heart")
+parser.add_argument("--data_dir", default=DANCEDIR / f"examples/tuning/temp_data")
 args = parser.parse_args()

 data_dir = args.data_dir
@@ -35,8 +36,8 @@
 tissue = args.tissue
 # conf_data = pd.read_csv(f"results/{tissue}_result.csv", index_col=0)
 conf_data = pd.read_excel("Cell Type Annotation Atlas.xlsx", sheet_name=tissue)
-target_files = list(conf_data[conf_data["queryed"] == False]["dataset_id"])
-source_files = list(conf_data[conf_data["queryed"] == True]["dataset_id"])
+atlas_datasets = list(conf_data[conf_data["queryed"] == False]["dataset_id"])
+query_datasets = list(conf_data[conf_data["queryed"] == True]["dataset_id"])


 class CustomEncoder(json.JSONEncoder):
@@ -117,46 +118,74 @@ def run_test_case(source_file):
"""
ans = {}
for target_file in target_files:
source_data = get_anndata(train_dataset=[f"{source_file}"], data_dir=data_dir, tissue=tissue.capitalize())

for target_file in atlas_datasets:
# source_data=sc.read_h5ad(f"{data_root}/{source_file}.h5ad")
# target_data=sc.read_h5ad(f"{data_root}/{target_file}.h5ad")
source_data = get_anndata(train_dataset=[f"{source_file}"], data_dir=data_dir)
target_data = get_anndata(train_dataset=[f"{target_file}"], data_dir=data_dir)
target_data = get_anndata(train_dataset=[f"{target_file}"], data_dir=data_dir, tissue=tissue.capitalize())

# Initialize similarity calculator with multiple metrics
similarity_calculator = AnnDataSimilarity(adata1=source_data, adata2=target_data, sample_size=10,
init_random_state=42, n_runs=1,
ground_truth_conf_path="Cell Type Annotation Atlas.xlsx",
adata1_name=source_file, adata2_name=target_file)
adata1_name=source_file, adata2_name=target_file, tissue=tissue)

# Calculate similarity using multiple methods
ans[target_file] = similarity_calculator.get_similarity_matrix_A2B(methods=[
"wasserstein", "Hausdorff", "chamfer", "energy", "sinkhorn2", "bures", "spectral", "common_genes_num",
"ground_truth", "mmd", "metadata_sim"
"wasserstein",
"Hausdorff",
"chamfer",
"energy",
"sinkhorn2",
"bures",
"spectral",
"common_genes_num",
# "ground_truth",
"mmd",
"metadata_sim"
])

# Convert results to DataFrame and save
ans = pd.DataFrame(ans)
ans.to_csv(f'sim_{source_file}.csv')
ans_to_path = f'sims/{tissue}/sim_{source_file}.csv'
os.makedirs(os.path.dirname(ans_to_path), exist_ok=True)
ans.to_csv(ans_to_path)
return ans


-query_data = os.listdir(file_root / "query_data")
-with pd.ExcelWriter(file_root / f"{tissue}_similarity.xlsx", engine='openpyxl') as writer:
-    for source_file in source_files:
-        query_ans = [
-            pd.read_csv(file_root / "query_data" / element, index_col=0) for element in query_data
-            if element.split("_")[-3] == source_file
-        ]
-        ans = run_test_case(source_file)
-        merged_df = pd.concat(query_ans + [ans], join='inner')
-        try:
+start = False
+query_data = os.listdir(file_root / "in_atlas_datas" / f"{tissue}")
+excel_path = file_root / f"{tissue}_similarity.xlsx"
+# with pd.ExcelWriter(file_root / f"{tissue}_similarity.xlsx", engine='openpyxl') as writer:
+for source_file in query_datasets:
+    # if source_file[:4]=='c777':
+    #     start=True
+    # if not start:
+    #     continue
+    query_ans = pd.concat([
+        pd.read_csv(file_root / "in_atlas_datas" / f"{tissue}" / element, index_col=0) for element in query_data
+        if element.split("_")[-3] == source_file
+    ])
+    rename_dict = {col: col.replace('_from_cache', '') for col in query_ans.columns if '_from_cache' in col}
+    query_ans = query_ans.rename(columns=rename_dict)
+    ans = run_test_case(source_file)
+    merged_df = pd.concat([query_ans, ans], join='inner')
+    if os.path.exists(excel_path):
+        excel = pd.ExcelFile(excel_path, engine='openpyxl')
+        if source_file[:4] in excel.sheet_names:
             # Try to read the specified sheet
-            existing_df = pd.read_excel(file_root / f"{tissue}_similarity.xlsx", sheet_name=source_file[:4])
+            existing_df = pd.read_excel(file_root / f"{tissue}_similarity.xlsx", sheet_name=source_file[:4],
+                                        engine="openpyxl", index_col=0)
             # Find rows that exist in the new DataFrame but not in the existing sheet
             merged_df = pd.concat([existing_df, merged_df])
-            merged_df = merged_df.drop_duplicates(keep='first')
-            # Use ExcelWriter to update the specific sheet
-            merged_df.to_excel(writer, sheet_name=source_file[:4], index=False)
-        except ValueError:
-            merged_df.to_excel(writer, sheet_name=source_file[:4], index=True)
+            merged_df = merged_df.drop_duplicates(subset=merged_df.index.name, keep='last')
+        excel.close()
+    if os.path.exists(excel_path):
+        mode = 'a'
+        if_sheet_exists = "replace"
+    else:
+        mode = 'w'
+        if_sheet_exists = None
+    with pd.ExcelWriter(excel_path, engine='openpyxl', mode=mode, if_sheet_exists=if_sheet_exists) as writer:
+        merged_df.to_excel(writer, sheet_name=source_file[:4])
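The removed try/except around one shared writer becomes an explicit append-or-create decision: open the existing workbook in append mode with if_sheet_exists="replace", or create a fresh one. A minimal, self-contained sketch of that pandas pattern (the helper name and file names below are illustrative, not from the commit):

import os

import pandas as pd


def write_sheet(df: pd.DataFrame, excel_path: str, sheet_name: str) -> None:
    """Write df to one sheet, replacing the sheet if the workbook already has it."""
    if os.path.exists(excel_path):
        mode, if_sheet_exists = "a", "replace"  # append to the existing workbook
    else:
        mode, if_sheet_exists = "w", None  # create a new workbook
    with pd.ExcelWriter(excel_path, engine="openpyxl", mode=mode,
                        if_sheet_exists=if_sheet_exists) as writer:
        df.to_excel(writer, sheet_name=sheet_name)


# The second call replaces sheet "s1" instead of raising, unlike the old try/except flow.
write_sheet(pd.DataFrame({"a": [1, 2]}), "demo.xlsx", "s1")
write_sheet(pd.DataFrame({"a": [3, 4]}), "demo.xlsx", "s1")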

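One detail worth flagging in the merge step above: DataFrame.drop_duplicates(subset=...) looks only at columns, so subset=merged_df.index.name behaves as intended only when the index name also exists as a column; with an unnamed index it quietly falls back to deduplicating on all columns. A hedged alternative, assuming the goal is "keep the newest row per index label", works on the index directly:

import pandas as pd

# Toy stand-ins for the cached results and the freshly computed ones.
old = pd.DataFrame({"wasserstein": [0.10]}, index=["metric_row"])
new = pd.DataFrame({"wasserstein": [0.25]}, index=["metric_row"])

merged = pd.concat([old, new])
# Keep the most recent entry for each duplicated index label (rows from `new` win).
merged = merged[~merged.index.duplicated(keep="last")]
print(merged)  # a single row with value 0.25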
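For downstream consumption, run_test_case now writes one matrix per query dataset to sims/{tissue}/sim_{source_file}.csv, with atlas dataset ids as columns. Assuming the rows are the metric names returned by get_similarity_matrix_A2B (not verified here, and the dataset id below is hypothetical), ranking atlas candidates for a query might look like:

import pandas as pd

tissue = "heart"
source_file = "0000-query-dataset-id"  # hypothetical query dataset id

sims = pd.read_csv(f"sims/{tissue}/sim_{source_file}.csv", index_col=0)
# Wasserstein is a distance, so smaller means more similar; sort ascending
# to put the closest atlas datasets first.
ranking = sims.loc["wasserstein"].sort_values(ascending=True)
print(ranking.head())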