diff --git a/examples/atlas/sc_similarity_examples/example_usage_anndata.py b/examples/atlas/sc_similarity_examples/example_usage_anndata.py
index 00c6a63d..42579be5 100644
--- a/examples/atlas/sc_similarity_examples/example_usage_anndata.py
+++ b/examples/atlas/sc_similarity_examples/example_usage_anndata.py
@@ -12,6 +12,7 @@
 from torch.utils.data import TensorDataset
 
 from dance.atlas.sc_similarity.anndata_similarity import AnnDataSimilarity, get_anndata
+from dance.settings import DANCEDIR, METADIR
 from dance.utils import set_seed
 
 # target_files = [
@@ -25,8 +26,8 @@
 # "eeacb0c1-2217-4cf6-b8ce-1f0fedf1b569"
 # ]
 parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-parser.add_argument("--tissue", type=str, default="blood")
-parser.add_argument("--data_dir", default="../../tuning/temp_data")
+parser.add_argument("--tissue", type=str, default="heart")
+parser.add_argument("--data_dir", default=DANCEDIR / "examples/tuning/temp_data")
 args = parser.parse_args()
 data_dir = args.data_dir
 
@@ -35,8 +36,8 @@
 tissue = args.tissue
 # conf_data = pd.read_csv(f"results/{tissue}_result.csv", index_col=0)
 conf_data = pd.read_excel("Cell Type Annotation Atlas.xlsx", sheet_name=tissue)
-target_files = list(conf_data[conf_data["queryed"] == False]["dataset_id"])
-source_files = list(conf_data[conf_data["queryed"] == True]["dataset_id"])
+atlas_datasets = list(conf_data[conf_data["queryed"] == False]["dataset_id"])
+query_datasets = list(conf_data[conf_data["queryed"] == True]["dataset_id"])
 
 
 class CustomEncoder(json.JSONEncoder):
@@ -117,46 +118,76 @@ def run_test_case(source_file):
 
     """
     ans = {}
-    for target_file in target_files:
+    # Loaded once per query dataset and reused for every atlas dataset below.
+    source_data = get_anndata(train_dataset=[f"{source_file}"], data_dir=data_dir, tissue=tissue.capitalize())
+
+    for target_file in atlas_datasets:
         # source_data=sc.read_h5ad(f"{data_root}/{source_file}.h5ad")
         # target_data=sc.read_h5ad(f"{data_root}/{target_file}.h5ad")
-        source_data = get_anndata(train_dataset=[f"{source_file}"], data_dir=data_dir)
-        target_data = get_anndata(train_dataset=[f"{target_file}"], data_dir=data_dir)
+        target_data = get_anndata(train_dataset=[f"{target_file}"], data_dir=data_dir, tissue=tissue.capitalize())
 
         # Initialize similarity calculator with multiple metrics
         similarity_calculator = AnnDataSimilarity(adata1=source_data, adata2=target_data, sample_size=10,
                                                   init_random_state=42, n_runs=1,
                                                   ground_truth_conf_path="Cell Type Annotation Atlas.xlsx",
-                                                  adata1_name=source_file, adata2_name=target_file)
+                                                  adata1_name=source_file, adata2_name=target_file, tissue=tissue)
 
         # Calculate similarity using multiple methods
         ans[target_file] = similarity_calculator.get_similarity_matrix_A2B(methods=[
-            "wasserstein", "Hausdorff", "chamfer", "energy", "sinkhorn2", "bures", "spectral", "common_genes_num",
-            "ground_truth", "mmd", "metadata_sim"
+            "wasserstein",
+            "Hausdorff",
+            "chamfer",
+            "energy",
+            "sinkhorn2",
+            "bures",
+            "spectral",
+            "common_genes_num",
+            # "ground_truth",
+            "mmd",
+            "metadata_sim"
         ])
 
     # Convert results to DataFrame and save
     ans = pd.DataFrame(ans)
-    ans.to_csv(f'sim_{source_file}.csv')
+    ans_to_path = f'sims/{tissue}/sim_{source_file}.csv'
+    os.makedirs(os.path.dirname(ans_to_path), exist_ok=True)
+    ans.to_csv(ans_to_path)
     return ans
 
 
-query_data = os.listdir(file_root / "query_data")
-with pd.ExcelWriter(file_root / f"{tissue}_similarity.xlsx", engine='openpyxl') as writer:
-    for source_file in source_files:
-        query_ans = [
-            pd.read_csv(file_root / "query_data" / element, index_col=0) for element in query_data
-            if element.split("_")[-3] == source_file
-        ]
-        ans = run_test_case(source_file)
-        merged_df = pd.concat(query_ans + [ans], join='inner')
-        try:
+query_data = os.listdir(file_root / "in_atlas_datas" / f"{tissue}")
+excel_path = file_root / f"{tissue}_similarity.xlsx"
+for source_file in query_datasets:
+    query_frames = [
+        pd.read_csv(file_root / "in_atlas_datas" / f"{tissue}" / element, index_col=0) for element in query_data
+        if element.split("_")[-3] == source_file
+    ]
+    ans = run_test_case(source_file)
+    if query_frames:
+        query_ans = pd.concat(query_frames)
+        rename_dict = {col: col.replace('_from_cache', '') for col in query_ans.columns if '_from_cache' in col}
+        query_ans = query_ans.rename(columns=rename_dict)
+        merged_df = pd.concat([query_ans, ans], join='inner')
+    else:
+        # pd.concat raises ValueError on an empty list, so use the fresh results alone.
+        merged_df = ans
+    file_exists = os.path.exists(excel_path)
+    if file_exists:
+        # Read the sheet names and release the workbook handle before the file is reopened.
+        with pd.ExcelFile(excel_path, engine='openpyxl') as excel:
+            sheet_exists = source_file[:4] in excel.sheet_names
+        if sheet_exists:
             # 尝试读取指定的分表
-            existing_df = pd.read_excel(file_root / f"{tissue}_similarity.xlsx", sheet_name=source_file[:4])
+            existing_df = pd.read_excel(excel_path, sheet_name=source_file[:4], engine="openpyxl", index_col=0)
+            # Rows recomputed in this run supersede stored rows that carry the same index label.
+            existing_df = existing_df[~existing_df.index.isin(merged_df.index)]
             # 找出在新数据框中存在但在现有表格中不存在的行
             merged_df = pd.concat([existing_df, merged_df])
-            merged_df = merged_df.drop_duplicates(keep='first')
-            # 使用 ExcelWriter 更新特定分表
-            merged_df.to_excel(writer, sheet_name=source_file[:4], index=False)
-        except ValueError:
-            merged_df.to_excel(writer, sheet_name=source_file[:4], index=True)
+    if file_exists:
+        mode = 'a'
+        if_sheet_exists = "replace"
+    else:
+        mode = 'w'
+        if_sheet_exists = None
+    with pd.ExcelWriter(excel_path, engine='openpyxl', mode=mode, if_sheet_exists=if_sheet_exists) as writer:
+        merged_df.to_excel(writer, sheet_name=source_file[:4])