-
Notifications
You must be signed in to change notification settings - Fork 186
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Example] Add preformer for precipitation nowcasting #976
base: develop
Are you sure you want to change the base?
Conversation
Thanks for your contribution! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感谢提交PR,有几处小问题麻烦看一下
docs/zh/examples/preformer.md
Outdated
|
||
``` sh | ||
# 模型训练 | ||
python examples/preformer/train.py |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
python examples/preformer/train.py | |
python train.py |
docs/zh/examples/preformer.md
Outdated
|
||
``` sh | ||
# 模型评估 | ||
python examples/preformer/train.py mode=eval |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
python examples/preformer/train.py mode=eval | |
python train.py mode=eval |
examples/preformer/train.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
文件建议改名为main.py
examples/preformer/train.py
Outdated
# set random seed for reproducibility | ||
ppsci.utils.misc.set_random_seed(cfg.seed) | ||
# initialize logger | ||
logger.init_logger("ppsci", osp.join(cfg.output_dir, "train.log"), "info") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
删除
examples/preformer/train.py
Outdated
"num_replicas": NUM_GPUS_PER_NODE, | ||
"rank": dist.get_rank() % NUM_GPUS_PER_NODE, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这两个参数应该不需要,并且paddlescience也没有对应的处理逻辑,默认会根据环境中设置的卡数自动设置
ppsci/data/dataset/era5sq_dataset.py
Outdated
mon = str("0") + mon | ||
day = str(self.time_table[idxs].timetuple().tm_mday) | ||
if len(day) == 1: | ||
day = str("0") + day |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
str("0")是否可以直接写成"0"?,下同
ppsci/data/dataset/era5sq_dataset.py
Outdated
r_data = np.load( | ||
os.path.join(self.file_path, year, "r_" + year + mon + day + hour + ".npy") | ||
) | ||
t_data = np.load( | ||
os.path.join(self.file_path, year, "t_" + year + mon + day + hour + ".npy") | ||
) | ||
u_data = np.load( | ||
os.path.join(self.file_path, year, "u_" + year + mon + day + hour + ".npy") | ||
) | ||
v_data = np.load( | ||
os.path.join(self.file_path, year, "v_" + year + mon + day + hour + ".npy") | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以直接使用f-string化简字符串拼接的写法
examples/preformer/conf/train.yaml
Outdated
hydra: | ||
run: | ||
# dynamic output directory according to running time and override name | ||
dir: outputs_preformer | ||
job: | ||
name: ${mode} # name of logfile | ||
chdir: false # keep current working directory unchanged | ||
config: | ||
override_dirname: | ||
exclude_keys: | ||
- TRAIN.checkpoint_path | ||
- TRAIN.trained_model_path | ||
- EVAL.trained_model_path | ||
- mode | ||
- output_dir | ||
- log_freq | ||
sweep: | ||
# output directory for multirun | ||
dir: ${hydra.run.dir} | ||
subdir: ./ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hydra: | |
run: | |
# dynamic output directory according to running time and override name | |
dir: outputs_preformer | |
job: | |
name: ${mode} # name of logfile | |
chdir: false # keep current working directory unchanged | |
config: | |
override_dirname: | |
exclude_keys: | |
- TRAIN.checkpoint_path | |
- TRAIN.trained_model_path | |
- EVAL.trained_model_path | |
- mode | |
- output_dir | |
- log_freq | |
sweep: | |
# output directory for multirun | |
dir: ${hydra.run.dir} | |
subdir: ./ | |
defaults: | |
- ppsci_default | |
- TRAIN: train_default | |
- TRAIN/ema: ema_default | |
- TRAIN/swa: swa_default | |
- EVAL: eval_default | |
- INFER: infer_default | |
- hydra/job/config/override_dirname/exclude_keys: exclude_keys_default | |
- _self_ | |
hydra: | |
run: | |
# dynamic output directory according to running time and override name | |
dir: outputs_preformer | |
job: | |
name: ${mode} # name of logfile | |
chdir: false # keep current working directory unchanged | |
sweep: | |
# output directory for multirun | |
dir: ${hydra.run.dir} | |
subdir: ./ | |
examples/preformer/conf/train.yaml
Outdated
|
||
# model settings | ||
MODEL: | ||
afno: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
单模型可以删除afno
这一层级
examples/preformer/conf/train.yaml
Outdated
afno: | ||
input_keys: ["input"] | ||
output_keys: ["output"] | ||
shape_in: [6, 12, IMG_H, IMG_W] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shape_in: [6, 12, IMG_H, IMG_W] | |
shape_in: | |
- 6 | |
- 12 | |
- ${IMG_H} | |
- ${IMG_W} | |
@EricKing19 标题已经修改过了,原先的merge code of upstream不太合适 |
docs/zh/examples/preformer.md
Outdated
案例中使用了预处理的 PEMSD4 和 PEMSD8 数据集。PEMSD4 为旧金山湾区交通数据,选取 29 条道路上 307 个传感器记录的交通数据,时间为 2018 年 1 月至 2 月。PEMSD8 为圣贝纳迪诺 8 条道路上 170 个检测器收集的交通数据,时间为 2016 年 7 月至 8 月。 | ||
|
||
两个数据集均被保存为 N x T x 1 的矩阵,记录了相应交通节点与时间的流量数据,其中 N 为交通节点数量,T 为时间序列长度。两个数据集分别按照 7:2:1 划分为训练集、验证集,和测试集。案例中预先计算了流量数据的均值与标准差,用于后续的正则化操作。 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
该案例是关于降水的,这个数据集好像是交通的,数据集与代码不一致
开始训练、评估前,请下载数据集文件 | ||
|
||
开始评估前,请下载或训练生成预训练模型 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以稍微介绍一下数据集的准备过程吗?比如如何下载和解压后的文件组织形式?
docs/zh/examples/preformer.md
Outdated
=== "模型训练命令" | ||
|
||
``` sh | ||
# 模型训练 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
删除这个注释,上面这个标签已经说明了这是模型训练命令了
docs/zh/examples/preformer.md
Outdated
=== "模型评估命令" | ||
|
||
``` sh | ||
# 模型评估 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,删除该行注释
docs/zh/examples/preformer.md
Outdated
|
||
``` sh | ||
# 模型评估 | ||
python train.py mode=eval |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里麻烦提供一下您训练好的预训练模型文件(.pdparams
文件即可),我们上传到bce上,这样就能通过在命令里直接指定预训练模型url直接下载并在评估前自动加载权重,不需要额外的手动下载了
docs/zh/examples/preformer.md
Outdated
#### 3.2.6 模型导出 | ||
|
||
通过设置 `ppsci.solver.Solver` 中的 `eval_during_train` 和 `eval_freq` 参数,可以自动保存在验证集上效果最优的模型参数。 | ||
|
||
``` py linenums="100" title="examples/preformer/train.py" | ||
--8<-- | ||
examples/preformer/train.py:158:158 | ||
--8<-- | ||
``` | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- 模型导出章节可以不用出现在文章中,删除
- 请补充模型导出的函数
def export
和def inference
到examples\preformer\main.py
中,参考:PaddleScience/examples/allen_cahn/allen_cahn_piratenet.py
Lines 235 to 269 in 83f6739
def export(cfg: DictConfig): # set model model = ppsci.arch.PirateNet(**cfg.MODEL) # initialize solver solver = ppsci.solver.Solver(model, cfg=cfg) # export model from paddle.static import InputSpec input_spec = [ {key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys}, ] solver.export(input_spec, cfg.INFER.export_path, with_onnx=False) def inference(cfg: DictConfig): from deploy.python_infer import pinn_predictor predictor = pinn_predictor.PINNPredictor(cfg) data = sio.loadmat(cfg.DATA_PATH) u_ref = data["usol"].astype(dtype) # (nt, nx) t_star = data["t"].flatten().astype(dtype) # [nt, ] x_star = data["x"].flatten().astype(dtype) # [nx, ] tx_star = misc.cartesian_product(t_star, x_star).astype(dtype) input_dict = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]} output_dict = predictor.predict(input_dict, cfg.INFER.batch_size) # mapping data to cfg.INFER.output_keys output_dict = { store_key: output_dict[infer_key] for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys()) } u_pred = output_dict["u"].reshape([len(t_star), len(x_star)]) plot(t_star, x_star, u_ref, u_pred, cfg.output_dir) - 模型导出和模型推理执行命令请添加到文档开头处的"=== "模型评估命令""后面
return latent | ||
|
||
|
||
class Mid_Xnet(nn.Layer): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mid_Xnet建议改为MidXNet,命名更规范
def forward(self, hid, enc1=None): | ||
for i in range(0, len(self.dec)): | ||
hid = self.dec[i](hid) | ||
# Y = self.dec[-1](torch.cat([hid, enc1], dim=1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这行注释是否可以删除 ?
for m in range(self.sq_length): | ||
x.append(self.load_data(global_idx + m)) | ||
for n in range(self.sq_length): | ||
# y.append(self.load_data(global_idx+n)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这行注释是否可以删除?
# y.append(self.load_data(global_idx+n)) | ||
y.append(self.precipitation["tp"][global_idx + self.sq_length + n]) | ||
# x = self.Normalize(x) | ||
x, y = self.RandomCrop(x, y) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.RandomCrop是否应该是self._random_crop?
def _random_crop(self, x, y): | ||
if isinstance(self.size, numbers.Number): | ||
self.size = (int(self.size), int(self.size)) | ||
th, tw = self.size | ||
h, w = y[0].shape[-2], y[0].shape[-1] | ||
x1 = random.randint(0, w - tw) | ||
y1 = random.randint(0, h - th) | ||
|
||
for i in range(len(x)): | ||
x[i] = self.crop(x[i], y1, x1, y1 + th, x1 + tw) | ||
for i in range(len(y)): | ||
y[i] = self.crop(y[i], y1, x1, y1 + th, x1 + tw) | ||
|
||
return x, y | ||
|
||
def crop(self, im, x_start, y_start, x_end, y_end): | ||
if len(im.shape) == 3: | ||
return im[:, x_start:x_end, y_start:y_end] | ||
else: | ||
return im[x_start:x_end, y_start:y_end] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
非公开方法前面建议加上下划线:
def _random_crop(self, x, y): | |
if isinstance(self.size, numbers.Number): | |
self.size = (int(self.size), int(self.size)) | |
th, tw = self.size | |
h, w = y[0].shape[-2], y[0].shape[-1] | |
x1 = random.randint(0, w - tw) | |
y1 = random.randint(0, h - th) | |
for i in range(len(x)): | |
x[i] = self.crop(x[i], y1, x1, y1 + th, x1 + tw) | |
for i in range(len(y)): | |
y[i] = self.crop(y[i], y1, x1, y1 + th, x1 + tw) | |
return x, y | |
def crop(self, im, x_start, y_start, x_end, y_end): | |
if len(im.shape) == 3: | |
return im[:, x_start:x_end, y_start:y_end] | |
else: | |
return im[x_start:x_end, y_start:y_end] | |
def _random_crop(self, x, y): | |
if isinstance(self.size, numbers.Number): | |
self.size = (int(self.size), int(self.size)) | |
th, tw = self.size | |
h, w = y[0].shape[-2], y[0].shape[-1] | |
x1 = random.randint(0, w - tw) | |
y1 = random.randint(0, h - th) | |
for i in range(len(x)): | |
x[i] = self._crop(x[i], y1, x1, y1 + th, x1 + tw) | |
for i in range(len(y)): | |
y[i] = self._crop(y[i], y1, x1, y1 + th, x1 + tw) | |
return x, y | |
def _crop(self, im, x_start, y_start, x_end, y_end): | |
if len(im.shape) == 3: | |
return im[:, x_start:x_end, y_start:y_end] | |
else: | |
return im[x_start:x_end, y_start:y_end] |
@EricKing19 顺带解决一下冲突问题 |
PR types
Others
PR changes
Others
Describe
add Preformer model for precipitation nowcasting
add docs for Preformer
add examples for Preformer