Solving the Helmholtz Problem with Neural Operators: Fourier Neural Operator
Deep Learning
AI4S
Helmholtz
PDE
曾祉竣
Updated 2024-06-30
Recommended image: Basic Image:ubuntu:22.04-py3.10-pytorch2.0
Recommended machine: c3_m4_1 * NVIDIA T4
Helmholtz_Single_Data(v1)

An FNO-Based Framework for Wavefield Prediction

©️ Copyright 2024 @ Authors
Author: 曾祉竣
Date: 2024-02-15
License: This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
Quick start: click the blue "Connect" button at the top of the page, select the public image `ubuntu:22.04-py3.10-pytorch2.0` and the `c3_m4_1 * NVIDIA T4` node configuration, and mount the `Helmholtz-single-dataset` dataset. After a short wait the notebook is ready to run.
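AI4SCUP competition note: this example is provided only as a reference, to help participants understand the problem setting and the neural-operator approach and to inspire their own methods. The data shown here is unrelated to the competition data.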


I. Ultrasound CT and Neural-Network Methods

1. Physics of ultrasound CT

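Ultrasound computed tomography (ultrasound CT) is an emerging high-resolution clinical imaging technique that is low-cost and radiation-free. In ultrasound CT, the object under observation is placed in a homogeneous medium (such as water), and sensors placed around it at equal angular intervals probe it (as shown in the figure below). The steady-state wavefield $u(\mathbf{x})$ reached by an acoustic wave propagating through space can be modeled by the Helmholtz equation

$$\nabla^2 u(\mathbf{x}) + \frac{\omega^2}{c^2(\mathbf{x})}\,u(\mathbf{x}) = S(\mathbf{x}), \quad \mathbf{x} \in \Omega,$$

where $c(\mathbf{x})$ is the sound speed and $\omega$ is the angular frequency. The region of interest $\Omega$ is a bounded square, and a single fixed emitter serves as the source, i.e.

$$S(\mathbf{x}) = A\,\delta(\mathbf{x} - \mathbf{x}_s),$$

where $\mathbf{x}_s$ is the source position and $A$ is the source amplitude. In addition, we control the emitter's mode, so the frequency $\omega$, and hence the wavenumber $k(\mathbf{x}) = \omega / c(\mathbf{x})$ for a given medium, is known. The forward problem of ultrasound CT is then: given the sound-speed distribution $c(\mathbf{x})$ in $\Omega$, compute the wavefield $u(\mathbf{x})$.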
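To make the forward map $c(\mathbf{x}) \mapsto u(\mathbf{x})$ concrete, here is a minimal 1D finite-difference sketch in NumPy. It is an illustration under simplifying assumptions, not necessarily the solver used to generate the dataset: the domain is 1D with uniform spacing h, the point source is discretized as a single grid value A/h, and the wavefield is forced to zero just outside the grid (Dirichlet) instead of using the absorbing boundary a realistic ultrasound CT simulation needs. The function name `helmholtz_1d` is hypothetical.

```python
import numpy as np

def helmholtz_1d(c, omega, h, src_idx, amplitude=1.0):
    """Solve u'' + (omega / c)^2 u = A * delta(x - x_s) on a uniform 1D grid.

    Simplifying assumption: u = 0 just outside the grid (Dirichlet), so there is
    no absorbing boundary as in a real ultrasound CT simulation.
    """
    c = np.asarray(c, dtype=float)
    n = len(c)
    k2 = (omega / c) ** 2                      # squared local wavenumber
    # second-order central difference for u'' plus the k^2 u term
    A = (np.diag(-2.0 / h**2 + k2)
         + np.diag(np.full(n - 1, 1.0 / h**2), 1)
         + np.diag(np.full(n - 1, 1.0 / h**2), -1))
    f = np.zeros(n)
    f[src_idx] = amplitude / h                 # discrete delta with total strength A
    u = np.linalg.solve(A, f)
    return u, A, f

# homogeneous medium (1500 m/s) vs. a medium with a faster inclusion
c_homo = np.full(201, 1500.0)
c_incl = c_homo.copy()
c_incl[120:140] = 1600.0
u_homo, _, _ = helmholtz_1d(c_homo, 1500.0 * 10.3, 0.01, 100)
u_incl, _, _ = helmholtz_1d(c_incl, 1500.0 * 10.3, 0.01, 100)
```

Changing c inside the domain changes u everywhere; learning this map on 2D data, without running a solver at inference time, is the job of the neural operator below.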

[3]
!pip install pytorch_lightning
Successfully installed aiohttp-3.9.5 aiosignal-1.3.1 async-timeout-4.0.3 frozenlist-1.4.1 fsspec-2024.6.1 lightning-utilities-0.11.3.post0 multidict-6.0.5 pytorch_lightning-2.3.1 torchmetrics-1.4.0.post0 yarl-1.9.4
[4]
!pip install matplotlib
Successfully installed contourpy-1.2.1 cycler-0.12.1 fonttools-4.53.0 kiwisolver-1.4.5 matplotlib-3.9.0 pillow-10.3.0 pyparsing-3.1.2

II. Third-Party Libraries Used

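| Library | Description |
|---|---|
| numpy | Provides C-like contiguous array storage and wraps many common array (tensor) operations |
| torch | A neural-network framework built on dynamic computational graphs, and a convenient GPU tensor library |
| pytorch-lightning | An open-source framework on top of PyTorch that organizes training code, helping researchers and engineers build neural-network models and training loops faster |
| matplotlib | A convenient plotting library for Python |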
[5]
from pytorch_lightning.callbacks import ModelCheckpoint,LearningRateMonitor
import pytorch_lightning as pl
import yaml
import argparse
from bisect import bisect
import os
import torch
import shutil
import warnings
import numpy as np
from torch.utils.data import Dataset, DataLoader
from collections import OrderedDict
import matplotlib.pyplot as plt
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
warnings.filterwarnings("ignore")

III. Data Loading Module


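We build the DataLoader in the following steps:

  1. Collect the file paths of the training and validation data (GettingLists class).
  2. Load the data from those paths (File_Loader class): given the locations of the sound-speed maps and wavefields, open them and build an iterable dataset.
  3. Fetch the input and output for a given index and preprocess them.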
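Step 3 needs to turn a global sample index into "which file, and which sample inside it". File_Loader does this with prefix sums of per-file sample counts plus `bisect`. The sketch below isolates that bookkeeping; `build_start_indices` and `locate` are hypothetical helper names, not part of the notebook.

```python
from bisect import bisect

def build_start_indices(counts):
    """Prefix sums: the global index at which each file's samples start."""
    starts, total = [], 0
    for n in counts:
        starts.append(total)
        total += n
    return starts, total

def locate(starts, index):
    """Map a global sample index to (file_index, index_within_file)."""
    # bisect returns the insertion point to the right of equal entries,
    # so subtracting 1 gives the last file whose start is <= index
    file_idx = bisect(starts, index) - 1
    return file_idx, index - starts[file_idx]

# e.g. three files holding 1, 1 and 3 samples
starts, total = build_start_indices([1, 1, 3])
print(starts, total)        # [0, 1, 2] 5
print(locate(starts, 4))    # (2, 2)
```

Because only the start offsets are stored, the lookup costs O(log n) per sample and works with memory-mapped files of any size.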
[6]
class File_Loader(Dataset):
    # Dataset that reads sound-speed maps and wavefields from .npy files
    def __init__(self, data_paths, target_paths, size=128):
        self.size = size
        # Source term: the homogeneous-medium wavefield, scaled by 2e-3
        self.src = torch.tensor(np.load('/bohr/breastsingle-hqot/v1/breast/u_homo.npy')) * 2e-3
        # Open the wave speed data as memory maps (read lazily from disk)
        print('Start Loading velocity')
        self.data_memmaps = [np.load(path, mmap_mode='r') for path in data_paths]
        print('Loading Velocity Done')
        # Open the wavefield data as memory maps
        print('Start Loading Wavefields')
        self.target_memmaps = [np.load(path, mmap_mode='r') for path in target_paths]
        print('Loading Wavefield Done')
        # Global index at which each velocity file starts
        self.start_indices = [0] * len(data_paths)
        self.data_count = 0
        for index, memmap in enumerate(self.data_memmaps):
            self.start_indices[index] = self.data_count
            if len(memmap.shape) == 3:
                self.data_count += 1
            elif len(memmap.shape) == 4:
                self.data_count += memmap.shape[0]
        # Global index at which each wavefield file starts
        self.start_indices_target = [0] * len(target_paths)
        self.data_count_target = 0
        for index, memmap in enumerate(self.target_memmaps):
            self.start_indices_target[index] = self.data_count_target
            self.data_count_target += memmap.shape[0]

    def __len__(self):
        return self.data_count_target

    def __getitem__(self, index):
        # Locate the file that contains this sample
        memmap_index = bisect(self.start_indices_target, index) - 1
        # Locate the sample inside that file
        index_in_memmap = index - self.start_indices_target[memmap_index]

        # Preprocessing: scaled slowness contrast relative to 1500 (the background medium maps to 0)
        data = (1500 / np.copy(self.data_memmaps[memmap_index][0, :, :]) - 1) * 30
        target = np.copy(self.target_memmaps[memmap_index][index_in_memmap, :, :]) * 2e-3
        # Split the complex wavefield into real and imaginary channels
        target = np.concatenate((np.real(target)[:, :, np.newaxis],
                                 np.imag(target)[:, :, np.newaxis]), axis=-1)
        src = self.src[0, :, :]
        return (torch.tensor(data, dtype=torch.float).view(self.size, self.size, 1),
                src.to(torch.float).view(self.size, self.size, 2),
                torch.tensor(target, dtype=torch.float).view(self.size, self.size, 2))
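File_Loader applies two fixed transforms: the sound speed c is mapped to 30 * (1500 / c - 1), a scaled slowness contrast relative to the background value 1500 used in the code (so the background medium maps to 0), and the complex wavefield is split into real and imaginary channels. The NumPy sketch below restates these transforms together with their inverses, which are handy for turning a model output back into physical quantities. The helper names are hypothetical and the constants are copied from the code above.

```python
import numpy as np

C0 = 1500.0     # background sound speed used in File_Loader
SCALE = 30.0    # contrast scaling used in File_Loader

def encode_sound_speed(c):
    """Sound speed -> scaled slowness contrast (0 for the background medium)."""
    return (C0 / c - 1.0) * SCALE

def decode_sound_speed(x):
    """Inverse of encode_sound_speed."""
    return C0 / (x / SCALE + 1.0)

def split_complex(u):
    """Complex field (H, W) -> real array (H, W, 2) holding (real, imag)."""
    return np.stack((u.real, u.imag), axis=-1).astype(np.float32)

def merge_complex(u2):
    """Real array (..., 2) -> complex field (...)."""
    return u2[..., 0] + 1j * u2[..., 1]
```

np.stack along a new last axis is equivalent to the np.concatenate of two [:, :, np.newaxis] slices used in File_Loader.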

[7]
class GettingLists(object):
    def __init__(self, train_num=1600,
                 valid_num=400,
                 PATH='lbs',
                 batchsize=int(2000)):
        super(GettingLists, self).__init__()
        self.PATH = PATH  # Data Path
        self.batchsize = batchsize
        self.valid_num = valid_num  # validation num
        # Sample ids start at 1; the validation ids follow the training ids
        self.velo_list_train = [i for i in range(1, train_num + 1)]
        self.velo_list_test = [i for i in range(train_num + 1, train_num + valid_num + 1)]
        self.pressure_list_train = self.velo_list_train
        self.pressure_list_test = self.velo_list_test

    def get_list(self, do):
        if do == 'train':
            in_limit_train = np.array([os.path.join(self.PATH, 'model', f'{k}.npy')
                                       for k in self.velo_list_train])
            out_limit_train = np.array([os.path.join(self.PATH, 'data', f'pressure{k}.npy')
                                        for k in self.pressure_list_train])
            return in_limit_train, out_limit_train
        elif do == 'validation':
            in_limit_valid = np.array([os.path.join(self.PATH, 'model', f'{k}.npy')
                                       for k in self.velo_list_test])
            out_limit_valid = np.array([os.path.join(self.PATH, 'data', f'pressure{k}.npy')
                                        for k in self.pressure_list_test])
            return in_limit_valid, out_limit_valid

    def __call__(self, do='train'):
        # obtain the list of path of dataset
        return self.get_list(do)

    def get_dataloader(self, do, workers, size):
        batchsize = self.batchsize
        if do == 'train':
            list_x_train, list_y_train = self.__call__('train')
            list_x_valid, list_y_valid = self.__call__('validation')
            Train_Data_set = File_Loader(list_x_train, list_y_train, size=size)
            Valid_Data_set = File_Loader(list_x_valid, list_y_valid, size=size)
            train_loader = DataLoader(dataset=Train_Data_set,
                                      shuffle=True,
                                      batch_size=batchsize,
                                      num_workers=workers)
            valid_loader = DataLoader(dataset=Valid_Data_set,
                                      shuffle=False,
                                      batch_size=batchsize,
                                      num_workers=workers)
            return train_loader, valid_loader

train_num = 1400
valid_num = 200
PATH = '/bohr/breastsingle-hqot/v1/breast'
batchsize = 10
workers = 10
size = 480
gl = GettingLists(train_num=train_num,
                  valid_num=valid_num,
                  PATH=PATH,
                  batchsize=batchsize)
train_loader, valid_loader = gl.get_dataloader(do='train',
                                               workers=workers,
                                               size=size)
Start Loading velocity
Loading Velocity Done
Start Loading Wavefields
Loading Wavefield Done
Start Loading velocity
Loading Velocity Done
Start Loading Wavefields
Loading Wavefield Done
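GettingLists numbers the samples consecutively from 1, training first and validation after, and derives both file names from the same id. Because `range()` excludes its upper bound, the validation range has to end at train_num + valid_num + 1 to yield exactly valid_num ids. The stdlib sketch below reproduces the split and the number of batches per epoch; `split_ids` and `sample_paths` are hypothetical helpers.

```python
import math
import os

def split_ids(train_num, valid_num):
    """Consecutive 1-based sample ids: training first, then validation."""
    train_ids = list(range(1, train_num + 1))
    # range() excludes its upper bound, hence the + 1
    valid_ids = list(range(train_num + 1, train_num + valid_num + 1))
    return train_ids, valid_ids

def sample_paths(root, k):
    """(sound-speed file, wavefield file) for sample k, same naming as GettingLists."""
    return (os.path.join(root, "model", f"{k}.npy"),
            os.path.join(root, "data", f"pressure{k}.npy"))

train_ids, valid_ids = split_ids(1400, 200)
# a DataLoader with drop_last=False yields ceil(N / batch_size) batches
batches_per_epoch = math.ceil(len(train_ids) / 10)
print(len(train_ids), len(valid_ids), batches_per_epoch)   # 1400 200 140
```

The 140 batches match the value printed by len(train_loader) in the next cell.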
[8]
with torch.no_grad():
    num = len(train_loader)
    print(num)  # number of training batches per epoch
    # Take one batch from the training loader
    sos, src, y = next(iter(train_loader))
    batchsize = sos.shape[0]
    print(batchsize)
    # Visualize the first sample of the batch
    sos = sos[0, :]
    src = src[0, :]
    y = y[0, :]
    fig, ax = plt.subplots(1, 3, figsize=(12, 4))
    ax = ax.flatten()
    ax[0].imshow(sos[..., 0], cmap="inferno")
    ax[0].set_title("Sound speed (preprocessed)")
    ax[1].imshow(src[..., 0], cmap="RdBu_r")
    ax[1].set_title("Homogeneous Field Real")
    ax[2].imshow(y[..., 0], cmap="RdBu_r")
    ax[2].set_title("Truth Wave field Real")
    plt.show()
140
10

IV. Building the Model

The architecture of the Fourier Neural Operator (FNO) is shown below:

[Figure: FNO architecture schematic]

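FNO learns a mapping from the PDE parameter function $a(x)$ to the solution function $u(x)$: the input $a(x)$ is lifted to a higher-dimensional channel space by an MLP $P$, passed through $T$ Fourier layers, and projected back down by an MLP $Q$ to give the output $u(x)$.
Inside one Fourier layer, the input goes through a Fourier transform $\mathcal{F}$, a linear transform $R$ applied to the retained Fourier modes, and an inverse Fourier transform $\mathcal{F}^{-1}$; this spectral branch is added to a pointwise linear transform $W$ of the input and passed through a nonlinearity.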
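Written out as equations, one common way to state this architecture is the following, where $v_t$ is the hidden field after layer $t$, $\sigma$ is the activation (GELU in the code below, the identity in the last layer), and $R$ multiplies each retained Fourier mode by a complex channel-mixing matrix while setting all other modes to zero:

$$
\begin{aligned}
v_0(x) &= P\big(a(x)\big),\\
v_{t+1}(x) &= \sigma\Big(W\,v_t(x) + \mathcal{F}^{-1}\big(R \cdot \mathcal{F}(v_t)\big)(x)\Big), \qquad t = 0,\dots,T-1,\\
u(x) &= Q\big(v_T(x)\big).
\end{aligned}
$$

In the code, $W$ is the 1×1 convolution `self.W` in Fourier_layer and $\mathcal{F}^{-1} R\, \mathcal{F}$ is fourier_conv_2d.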

Next, let's walk through the code to understand it in more detail:


First, we build some basic modules, such as a fully connected network (with dropout) and LayerNorm.
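The LayerNorm module below normalizes each pixel's feature vector over the channel axis: the channels_last format calls F.layer_norm on the last axis, while channels_first computes the same statistics by hand over axis 1 (mean, biased variance, then a per-channel affine transform). A NumPy sketch, with `layer_norm` as a hypothetical helper, shows that both formats agree up to a transpose.

```python
import numpy as np

def layer_norm(x, weight, bias, axis, eps=1e-6):
    """Normalize x over `axis`, then apply a per-channel scale and shift."""
    mean = x.mean(axis=axis, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=axis, keepdims=True)   # biased variance
    xn = (x - mean) / np.sqrt(var + eps)
    # reshape the affine parameters so they broadcast along `axis`
    shape = [1] * x.ndim
    shape[axis] = -1
    return weight.reshape(shape) * xn + bias.reshape(shape)

rng = np.random.default_rng(0)
x_last = rng.normal(size=(2, 4, 4, 3))            # (batch, height, width, channels)
w, b = rng.normal(size=3), rng.normal(size=3)
out_last = layer_norm(x_last, w, b, axis=-1)       # channels_last
out_first = layer_norm(x_last.transpose(0, 3, 1, 2), w, b, axis=1)  # channels_first
```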

[9]
# Fully Connected Layer
class FCLayer(nn.Module):
    """Fully connected layer"""
    def __init__(self, in_feature, out_feature,
                 activation="gelu",
                 is_normalized=True):
        super().__init__()
        if is_normalized:
            self.LinearBlock = nn.Sequential(
                nn.Linear(in_feature, out_feature),
                LayerNorm(out_feature),
            )
        else:
            self.LinearBlock = nn.Linear(in_feature, out_feature)
        # any truthy activation selects GELU; None gives the identity
        if activation:
            self.act = F.gelu
        else:
            self.act = nn.Identity()

    def forward(self, x):
        return self.act(self.LinearBlock(x))

#####################################################################
# Fully Connected Neural Networks
class FC_nn(nn.Module):
    r"""Simple MLP to code lifting and projection"""
    def __init__(self, sizes=[2, 128, 128, 1],
                 outermost_linear=True,
                 outermost_norm=True,
                 drop=0.):
        super().__init__()
        self.dropout = nn.Dropout(drop)
        # hidden layers: Linear + GELU
        self.net = nn.ModuleList([FCLayer(in_feature=m, out_feature=n,
                                          activation='gelu',
                                          is_normalized=False)
                                  for m, n in zip(sizes[:-2], sizes[1:-1])])
        # output layer: purely linear, or followed by GELU
        if outermost_linear:
            self.net.append(FCLayer(sizes[-2], sizes[-1], activation=None,
                                    is_normalized=outermost_norm))
        else:
            self.net.append(FCLayer(in_feature=sizes[-2], out_feature=sizes[-1],
                                    activation='gelu',
                                    is_normalized=outermost_norm))

    def forward(self, x):
        for module in self.net:
            x = module(x)
            x = self.dropout(x)
        return x

#####################################################################
# LayerNorm Module
class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x

#####################################################################
# Getting the 2D grid using the batch: (batch, size_x, size_y, 2) coordinates in [0, 1]
def get_grid2D(shape, device):
    batchsize, size_x, size_y = shape[0], shape[1], shape[2]
    gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
    gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
    gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float)
    gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1])
    return torch.cat((gridx, gridy), dim=-1).to(device)
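get_grid2D appends the normalized coordinates (x, y) in [0, 1]² as two extra input channels; together with the two source channels (real and imaginary part) this is why the FNO class adds 4 to dim_input. The same grid can be built with meshgrid and broadcasting; `grid2d` is a hypothetical NumPy equivalent shown for illustration.

```python
import numpy as np

def grid2d(batchsize, size_x, size_y):
    """(batch, size_x, size_y, 2) array whose last axis holds (x, y) in [0, 1]."""
    gx, gy = np.meshgrid(np.linspace(0, 1, size_x),
                         np.linspace(0, 1, size_y), indexing="ij")
    grid = np.stack((gx, gy), axis=-1)                  # (size_x, size_y, 2)
    # broadcast_to returns a read-only view instead of copying per batch element
    return np.broadcast_to(grid, (batchsize, size_x, size_y, 2))

g = grid2d(2, 3, 5)
```

indexing="ij" makes the first coordinate vary along axis 1 and the second along axis 2, matching the reshape/repeat layout in get_grid2D.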

Next, we build the Fourier Neural Operator on top of PyTorch Lightning.

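The fourier_conv_2d class implements the Fourier convolution, corresponding to the yellow box in panel (b) of the FNO schematic. In it:

in_, out_: the numbers of input and output channels (in_channels and out_channels)
wavenumber1, wavenumber2: the number of Fourier modes kept along each of the two spatial dimensions
weights1, weights2: the learnable complex weights, stored as (real, imag) pairs; weights1 acts on the first wavenumber1 rows of the spectrum (non-negative frequencies along the first axis), weights2 on the last wavenumber1 rows (negative frequencies)
compl_mul2d: defines how the input and the weights are multiplied, i.e. a complex multiplication summed over the input channels
The forward method then applies the Fourier transform, multiplies the retained modes by the weights (discarding all others), and applies the inverse Fourier transform.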
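The steps in forward can be mirrored in NumPy with native complex arrays, which makes the mode truncation easy to inspect. This is a minimal sketch, not the notebook's PyTorch class: `spectral_conv2d` is a hypothetical name, and the weights are plain complex arrays rather than (..., 2) real pairs.

```python
import numpy as np

def spectral_conv2d(x, w1, w2, m1, m2):
    """NumPy sketch of a 2D Fourier convolution.

    x:      real array (batch, in_ch, H, W)
    w1, w2: complex arrays (in_ch, out_ch, m1, m2); w1 acts on the first m1 rows
            of the spectrum (non-negative frequencies along H), w2 on the last m1
            rows (negative frequencies). Assumes 2 * m1 <= H and m2 <= W // 2 + 1.
    """
    b, _, H, W = x.shape
    out_ch = w1.shape[1]
    x_ft = np.fft.rfft2(x)                          # (b, in_ch, H, W // 2 + 1)
    out_ft = np.zeros((b, out_ch, H, W // 2 + 1), dtype=complex)
    # multiply the retained modes by the weights, summing over input channels
    out_ft[:, :, :m1, :m2] = np.einsum("bixy,ioxy->boxy", x_ft[:, :, :m1, :m2], w1)
    out_ft[:, :, -m1:, :m2] = np.einsum("bixy,ioxy->boxy", x_ft[:, :, -m1:, :m2], w2)
    # every other mode stays zero, i.e. it is truncated
    return np.fft.irfft2(out_ft, s=(H, W))
```

With all-ones weights and enough modes to cover the whole spectrum the layer returns its input unchanged, while a pattern whose frequency lies above m2 is removed entirely. That low-pass behavior is what wavenumber1 and wavenumber2 control.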

[10]
# Fourier convolution 2D block
class fourier_conv_2d(nn.Module):
    def __init__(self, in_, out_, wavenumber1, wavenumber2):
        super(fourier_conv_2d, self).__init__()
        self.out_ = out_
        self.wavenumber1 = wavenumber1
        self.wavenumber2 = wavenumber2
        scale = (1 / (in_ * out_))
        # complex weights stored as (..., 2) = (real, imag) pairs
        self.weights1 = nn.Parameter(scale * torch.rand(in_, out_, wavenumber1, wavenumber2, 2, dtype=torch.float32))
        self.weights2 = nn.Parameter(scale * torch.rand(in_, out_, wavenumber1, wavenumber2, 2, dtype=torch.float32))

    # Complex multiplication on (real, imag) pairs, summed over input channels
    def compl_mul2d(self, input, weights):
        # (batch, in_channel, x, y, 2), (in_channel, out_channel, x, y, 2) -> (batch, out_channel, x, y, 2)
        # (a + ib)(c + id) = (ac - bd) + i(ad + bc)
        a, b = input[..., 0], input[..., 1]
        c, d = weights[..., 0], weights[..., 1]
        real = torch.einsum("bixy,ioxy->boxy", a, c) - torch.einsum("bixy,ioxy->boxy", b, d)
        imag = torch.einsum("bixy,ioxy->boxy", a, d) + torch.einsum("bixy,ioxy->boxy", b, c)
        return torch.stack((real, imag), dim=-1)

    def forward(self, x):
        # input: (batch, channel, x, y); output: (batch, out_channel, x, y)
        batchsize = x.shape[0]
        # Fourier coefficients of the last two dims: (batch, channel, x, y//2 + 1, 2)
        x_ft = torch.view_as_real(torch.fft.rfft2(x))
        # Multiply the retained low-frequency modes; all other modes stay zero
        out_ft = torch.zeros(batchsize, self.out_, x.size(-2), x.size(-1) // 2 + 1, 2, dtype=torch.float32, device=x.device)
        out_ft[:, :, :self.wavenumber1, :self.wavenumber2, :] = \
            self.compl_mul2d(x_ft[:, :, :self.wavenumber1, :self.wavenumber2, :], self.weights1)
        out_ft[:, :, -self.wavenumber1:, :self.wavenumber2, :] = \
            self.compl_mul2d(x_ft[:, :, -self.wavenumber1:, :self.wavenumber2, :], self.weights2)
        # Return to physical space
        x = torch.fft.irfft2(torch.view_as_complex(out_ft), s=(x.size(-2), x.size(-1)))
        return x
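compl_mul2d stores complex numbers as (real, imag) pairs in the last axis. A complex product of such pairs needs cross terms, (a + ib)(c + id) = (ac - bd) + i(ad + bc); a single einsum over the trailing axis such as "bixyz,ioxyz->boxyz" only multiplies real with real and imaginary with imaginary. The NumPy check below compares a pair-based implementation with NumPy's native complex einsum; `compl_mul2d_pairs` is a hypothetical helper written for this check.

```python
import numpy as np

def compl_mul2d_pairs(inp, w):
    """(b, i, x, y, 2) x (i, o, x, y, 2) -> (b, o, x, y, 2) complex product, summed over i."""
    a, b = inp[..., 0], inp[..., 1]
    c, d = w[..., 0], w[..., 1]
    real = np.einsum("bixy,ioxy->boxy", a, c) - np.einsum("bixy,ioxy->boxy", b, d)
    imag = np.einsum("bixy,ioxy->boxy", a, d) + np.einsum("bixy,ioxy->boxy", b, c)
    return np.stack((real, imag), axis=-1)

rng = np.random.default_rng(0)
inp = rng.normal(size=(2, 3, 4, 5, 2))
w = rng.normal(size=(3, 6, 4, 5, 2))
# reference: the same contraction carried out with native complex numbers
ref = np.einsum("bixy,ioxy->boxy", inp[..., 0] + 1j * inp[..., 1], w[..., 0] + 1j * w[..., 1])
out = compl_mul2d_pairs(inp, w)
```

An equivalent alternative is to keep the spectrum complex throughout (torch.fft.rfft2 already returns complex tensors) and apply torch.view_as_complex to the weights.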

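Fourier_layer is one Fourier layer: the output of the Fourier convolution is added to a pointwise 1×1 convolution W of the input and passed through GELU (the identity in the last layer).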

[11]
# Fourier layer: Fourier conv block plus a pointwise 1x1 Conv2d block
class Fourier_layer(nn.Module):
    def __init__(self, features_, wavenumber, is_last=False):
        super(Fourier_layer, self).__init__()
        self.W = nn.Conv2d(features_, features_, 1)
        self.fourier_conv = fourier_conv_2d(features_, features_, *wavenumber)
        if is_last == False:
            self.act = F.gelu
        else:
            self.act = nn.Identity()

    def forward(self, x):
        x1 = self.fourier_conv(x)
        x2 = self.W(x)
        return self.act(x1 + x2)
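A Fourier layer therefore computes act(K(x) + W x), where K is the spectral convolution and W is the 1×1 convolution self.W, i.e. the same channel-mixing matrix applied at every pixel. The NumPy sketch below makes that structure explicit; `gelu`, `pointwise_conv` and `fourier_layer` are hypothetical helpers, and `spectral` stands in for any spectral convolution (a zero map here, to test the local path alone).

```python
import math
import numpy as np

def gelu(x):
    """Exact GELU, x * Phi(x), the default form of torch.nn.functional.gelu."""
    erf = np.vectorize(math.erf)
    return 0.5 * x * (1.0 + erf(x / math.sqrt(2.0)))

def pointwise_conv(x, W, b):
    """1x1 convolution on (batch, in, H, W): one (out, in) matrix applied at every pixel."""
    return np.einsum("oi,bixy->boxy", W, x) + b[None, :, None, None]

def fourier_layer(x, spectral, W, b, is_last=False):
    """Sum of the spectral and pointwise branches, then GELU (identity in the last layer)."""
    y = spectral(x) + pointwise_conv(x, W, b)
    return y if is_last else gelu(y)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 4, 4))
W, b = rng.normal(size=(3, 3)), rng.normal(size=3)
zero = lambda v: np.zeros_like(v)   # placeholder spectral branch
y = fourier_layer(x, zero, W, b, is_last=True)
```

The pointwise path handles local, non-periodic detail that the truncated spectrum cannot represent, which is one reason FNO layers keep both branches.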

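Finally, we assemble the FNO class from the modules built above.
self.lifting corresponds to P in the schematic: a simple MLP encoder that lifts the input to a higher channel dimension;
self.proj corresponds to Q in the schematic: a simple MLP decoder that projects back down to the two output channels (real and imaginary part of the wavefield).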

[12]
class FNO(pl.LightningModule):
    def __init__(self,
                 wavenumber, features_,
                 padding=9,
                 lifting=None,     # unused
                 proj=None,        # unused
                 dim_input=1,
                 with_grid=True,
                 add_term=True,
                 learning_rate=1e-2,
                 step_size=100,
                 gamma=0.5,        # unused by the cosine scheduler below
                 weight_decay=1e-5,
                 eta_min=5e-4):
        super(FNO, self).__init__()
        self.with_grid = with_grid
        self.padding = padding
        self.layers = len(wavenumber)
        self.learning_rate = learning_rate
        self.step_size = step_size
        self.gamma = gamma
        self.weight_decay = weight_decay
        self.eta_min = eta_min
        self.add_term = add_term
        self.criterion = nn.MSELoss()
        self.criterion_val = nn.MSELoss()
        if with_grid:
            # + 2 channels for the source field (real, imag) + 2 grid coordinates
            dim_input += 4
        # P: lifting MLP (encoder)
        self.lifting = FC_nn([dim_input, features_ // 2, features_],
                             outermost_norm=False)
        # Q: projection MLP (decoder), outputs real and imaginary parts
        self.proj = FC_nn([features_, features_ // 2, 2],
                          outermost_norm=False)
        self.fno = []
        for l in range(self.layers - 1):
            self.fno.append(Fourier_layer(features_=features_,
                                          wavenumber=[wavenumber[l]] * 2))
        self.fno.append(Fourier_layer(features_=features_,
                                      wavenumber=[wavenumber[-1]] * 2,
                                      is_last=True))
        self.fno = nn.Sequential(*self.fno)

    def forward(self, sos, src):
        # forward process of FNO
        x = sos
        if self.with_grid:
            grid = get_grid2D(x.shape, x.device)
            x = torch.cat((x, src, grid), dim=-1)
        x = self.lifting(x)                        # (batch, H, W, features_)
        x = x.permute(0, 3, 1, 2)                  # channels first for the Fourier layers
        x = nn.functional.pad(x, [0, self.padding, 0, self.padding])
        x = self.fno(x)
        x = x[..., :-self.padding, :-self.padding]  # remove the padding again
        x = x.permute(0, 2, 3, 1)
        x = self.proj(x)                           # (batch, H, W, 2)
        if self.add_term:
            # the network predicts a relative correction: out = src * (1 + x)
            x = torch.view_as_real(torch.view_as_complex(src.to(x.device)) * (1 + torch.view_as_complex(x)))
        return x

    def training_step(self, batch: torch.Tensor, batch_idx):
        # One step training
        sos, src, y = batch
        batch_size = sos.shape[0]
        out = self(sos, src)
        loss = self.criterion(out.reshape(batch_size, -1), y.reshape(batch_size, -1))
        self.log("loss", loss, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, val_batch: torch.Tensor, batch_idx):
        # One step validation
        sos, src, y = val_batch
        batch_size = sos.shape[0]
        out = self(sos, src)
        val_loss = self.criterion_val(out.reshape(batch_size, -1), y.reshape(batch_size, -1))
        self.log('val_loss', val_loss, on_epoch=True, prog_bar=True, logger=True)
        return val_loss

    def configure_optimizers(self, optimizer=None, scheduler=None):
        if optimizer is None:
            optimizer = optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        if scheduler is None:
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.step_size, eta_min=self.eta_min)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler
            },
        }
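Two details of FNO.forward are easy to miss. First, nn.functional.pad(x, [0, p, 0, p]) zero-pads only the right end of the last axis and the bottom end of the second-to-last axis by p = self.padding, and the slice [..., :-p, :-p] removes exactly that padding after the Fourier layers (padding is a common way to soften the periodicity the FFT assumes; note the slice only works for p > 0). Second, with add_term=True the network output x is not the wavefield itself: the prediction is src * (1 + x) in complex arithmetic, so x acts as a relative correction to the homogeneous field. The NumPy sketch below mirrors both steps; `pad_crop` and `add_term` are hypothetical helper names.

```python
import numpy as np

def pad_crop(x, padding):
    """Zero-pad (batch, ch, H, W) on the right/bottom, then crop back; needs padding > 0."""
    xp = np.pad(x, ((0, 0), (0, 0), (0, padding), (0, padding)))
    return xp, xp[..., :-padding, :-padding]

def add_term(src2, x2):
    """Combine (..., 2) real/imag pairs as src * (1 + x) and return (..., 2) pairs."""
    src = src2[..., 0] + 1j * src2[..., 1]
    x = x2[..., 0] + 1j * x2[..., 1]
    out = src * (1 + x)
    return np.stack((out.real, out.imag), axis=-1)

x = np.arange(72, dtype=float).reshape(1, 2, 6, 6)
xp, xc = pad_crop(x, 3)
rng = np.random.default_rng(0)
src2 = rng.normal(size=(1, 4, 4, 2))
```

Because a zero network output reproduces the homogeneous field exactly, the model only has to learn the scattering caused by the sound-speed variations.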
[13]
wavenumber = [30, 30, 30, 30,30,30,30]
features_ = 20
padding = 9
dim_input = 1
with_grid = True
learning_rate = 1e-2
step_size = 20
gamma = 0.5
weight_decay = 1e-5
add_term = True
eta_min = 5e-4

model = FNO(wavenumber = wavenumber, features_ = features_,
padding = padding,
lifting = None,
proj = None,
dim_input = dim_input,
with_grid= with_grid,
learning_rate = learning_rate,
step_size = step_size,
gamma = gamma,
weight_decay = weight_decay,
add_term = add_term,
eta_min = eta_min)
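configure_optimizers uses torch.optim.lr_scheduler.CosineAnnealingLR with T_max = step_size = 20 while the trainer below runs for max_epochs = 100. Assuming the scheduler is stepped once per epoch (Lightning's default interval), the learning rate follows the closed-form cosine documented for CosineAnnealingLR. That curve is periodic: it does not stay at eta_min after T_max epochs but climbs back to the initial rate, cycling every 2 * T_max = 40 epochs. The plain-Python sketch below evaluates the closed form; `cosine_annealing_lr` is a hypothetical helper.

```python
import math

def cosine_annealing_lr(epoch, base_lr, t_max, eta_min):
    """Closed-form cosine annealing (no warm restarts)."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2

# hyperparameters from the cell above: learning_rate=1e-2, step_size=20, eta_min=5e-4
lrs = [cosine_annealing_lr(e, 1e-2, 20, 5e-4) for e in range(100)]
```

If a single decay over the whole run is intended, setting step_size equal to max_epochs would give that behavior.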
[1]
max_epochs = 100
trainer = pl.Trainer(max_epochs=max_epochs,
accelerator= 'gpu',
devices = 1,
)
trainer.fit(model, train_loader, valid_loader)