Spaces: Running on Zero

HanzhouLiu committed · Commit b56342d
1 Parent(s): 9f55394
Add application file

This view is limited to 50 files because it contains too many changes. See raw diff.
- README.md +17 -7
- app.py +123 -0
- aug.py +56 -0
- config/xyscannetp_gopro/config_stage1.yaml +40 -0
- config/xyscannetp_gopro/config_stage2.yaml +41 -0
- config/xyscannetp_realj/config_stage2.yml +40 -0
- config/xyscannetp_realr/config_stage2.yml +40 -0
- dataset.py +140 -0
- datasets/datasets.txt +2 -0
- evaluate_NIQE.m +57 -0
- evaluate_RealBlur_J.py +117 -0
- evaluate_RealBlur_R.py +110 -0
- evaluation_GoPro.m +60 -0
- evaluation_HIDE.m +60 -0
- examples/blur1.png +3 -0
- examples/blur2.png +3 -0
- examples/blur3.png +3 -0
- examples/blur4.png +3 -0
- examples/blur5.png +3 -0
- license +37 -0
- metric_counter.py +55 -0
- models/XYScanNet.py +737 -0
- models/XYScanNetP.py +737 -0
- models/__init__.py +0 -0
- models/__pycache__/XYScanNet.cpython-38.pyc +0 -0
- models/__pycache__/XYScanNetP.cpython-38.pyc +0 -0
- models/__pycache__/__init__.cpython-38.pyc +0 -0
- models/__pycache__/networks.cpython-38.pyc +0 -0
- models/losses.py +233 -0
- models/models.py +36 -0
- models/networks.py +16 -0
- models/sota/FFTformer.py +324 -0
- models/sota/Restormer.py +340 -0
- models/sota/Stripformer.py +429 -0
- models/sota/XYScanNet.py +754 -0
- out/Results.txt +1 -0
- predict_GoPro_test_results.py +89 -0
- predict_HIDE_test_results.py +69 -0
- predict_RWBI_test_results.py +88 -0
- predict_RealBlur_J_test_results.py +97 -0
- predict_RealBlur_R_test_results.py +96 -0
- requirements.txt +11 -0
- results/xyscannetp_gopro/models/best_XYScanNet_stage2.pth +3 -0
- schedulers.py +59 -0
- train_XYScanNet_stage1.py +182 -0
- train_XYScanNet_stage2.py +182 -0
- util/__init__.py +0 -0
- util/__pycache__/__init__.cpython-310.pyc +0 -0
- util/__pycache__/__init__.cpython-36.pyc +0 -0
- util/__pycache__/__init__.cpython-38.pyc +0 -0
README.md
CHANGED
@@ -1,13 +1,23 @@
 ---
-title: XYScanNet
-emoji:
-colorFrom:
-colorTo:
+title: XYScanNet
+emoji: 🚀
+colorFrom: blue
+colorTo: indigo
 sdk: gradio
-sdk_version:
+sdk_version: "4.44.1"
 app_file: app.py
 pinned: false
-license: apache-2.0
 ---
 
-
+# XYScanNet: Mamba-based Image Deblurring Demo
+
+This Space runs the **XYScanNet** deblurring model on GPU using **Gradio**.
+Upload a blurry image, and the model will restore a sharp version automatically.
+
+🧠 **Tech Highlights**
+- Based on the **Mamba selective state space model**
+- Implements cross-directional strip attention (horizontal & vertical)
+- Runs efficiently on GPU with automatic padding
+
+👤 Author: [Hanzhou Liu](https://huggingface.co/HanzhouLiu)
+📦 Model weights: [HanzhouLiu/XYScanNet-weights](https://huggingface.co/spaces/HanzhouLiu/XYScanNet_Demo)
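The README above describes an interactive demo. For readers who prefer to call the Space programmatically, a minimal sketch using gradio_client follows; the Space id is taken from the README link, and the endpoint name and argument handling are assumptions about the default gr.Interface API rather than something documented in this commit.

# Minimal sketch: call the deblurring Space from Python with gradio_client.
# Assumes the Space "HanzhouLiu/XYScanNet_Demo" is running and public; the
# "/predict" endpoint is the gradio default for a single-function Interface.
from gradio_client import Client, handle_file

client = Client("HanzhouLiu/XYScanNet_Demo")
result = client.predict(
    handle_file("examples/blur1.png"),  # any local blurry image
    api_name="/predict",
)
print("Deblurred image written to:", result)  # gradio_client returns a file path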
app.py
ADDED
@@ -0,0 +1,123 @@
import gradio as gr
import spaces
import torch
import torchvision
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from PIL import Image
import yaml
import os
from models.networks import get_generator


# ===========================
# 1. Device setup
# ===========================
# Automatically choose GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🔥 Using device: {device}")


# ===========================
# 2. Model Loading
# ===========================
def load_model(job_name="xyscannetp_gopro"):
    """
    Load the pretrained XYScanNet model on CPU or GPU automatically.
    """
    cfg_path = os.path.join("config", job_name, "config_stage2.yaml")
    with open(cfg_path, "r") as f:
        config = yaml.safe_load(f)

    weights_path = os.path.join(
        "results", job_name, "models", f"best_{config['experiment_desc']}.pth"
    )
    print(f"🔹 Loading model from {weights_path}")

    model = get_generator(config["model"])
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval().to(device)

    print(f"✅ Model loaded on {device}")
    return model


print("Initializing XYScanNet model...")
MODEL = load_model()
print("Model ready.")


# ===========================
# 3. Helper functions
# ===========================
def pad_to_multiple_of_8(img_tensor):
    """
    Pad the image tensor so that both height and width are multiples of 8.
    """
    _, _, h, w = img_tensor.shape
    pad_h = (8 - h % 8) % 8
    pad_w = (8 - w % 8) % 8
    img_tensor = F.pad(img_tensor, (0, pad_w, 0, pad_h), mode="reflect")
    return img_tensor, h, w


def crop_back(img_tensor, orig_h, orig_w):
    """Crop output back to original image size."""
    return img_tensor[:, :, :orig_h, :orig_w]


# ===========================
# 4. Inference Function
# ===========================
# The decorator below *requests* GPU if available,
# but won't crash if only CPU exists.
@spaces.GPU
def run_deblur(input_image: Image.Image):
    """
    Run deblurring inference on GPU if available, else CPU.
    """
    # Convert PIL RGB → Tensor [B,C,H,W] normalized to [-0.5,0.5]
    img = np.array(input_image.convert("RGB"))
    img_tensor = (
        torch.from_numpy(np.transpose(img / 255.0, (2, 0, 1)).astype("float32")) - 0.5
    )
    img_tensor = Variable(img_tensor.unsqueeze(0)).to(device)

    # Pad to valid window size
    img_tensor, orig_h, orig_w = pad_to_multiple_of_8(img_tensor)

    # Inference
    with torch.no_grad():
        result_image, _, _ = MODEL(img_tensor)
        result_image = result_image + 0.5
        result_image = crop_back(result_image, orig_h, orig_w)

    # Convert to PIL Image for display
    out_img = result_image.squeeze(0).clamp(0, 1).cpu()
    out_pil = torchvision.transforms.ToPILImage()(out_img)
    return out_pil


# ===========================
# 5. Gradio Interface
# ===========================
demo = gr.Interface(
    fn=run_deblur,
    inputs=gr.Image(type="pil", label="Upload a Blurry Image"),
    outputs=gr.Image(type="pil", label="Deblurred Result"),
    title="XYScanNet: Mamba-based Image Deblurring (GPU Demo)",
    description=(
        "Upload a blurry image to see how XYScanNet restores it using a Mamba-based vision state-space model."
    ),
    examples=[
        # the example images added in this commit are .png files in examples/
        ["examples/blur1.png"],
        ["examples/blur2.png"],
        ["examples/blur3.png"],
        ["examples/blur4.png"],
        ["examples/blur5.png"],
    ],
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()
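For a quick local check of the same inference path without the web UI, the function defined in app.py can be called directly. This is only a sketch: importing app loads the model at import time, so the stage-2 config, the .pth weights from this commit, and the mamba_ssm dependencies must already be in place.

# Sketch: run one image through the deblurring pipeline defined in app.py.
# Assumes this is executed from the repository root with all weights present.
from PIL import Image

import app  # importing app builds MODEL from results/xyscannetp_gopro/models/

blurry = Image.open("examples/blur1.png")
sharp = app.run_deblur(blurry)  # pads, runs the network, crops back
sharp.save("deblurred.png")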
aug.py
ADDED
@@ -0,0 +1,56 @@
from typing import List

import albumentations as albu
from torchvision import transforms

def get_transforms(size: int, scope: str = 'geometric', crop='random'):
    augs = {'strong': albu.Compose([albu.HorizontalFlip(),
                                    albu.ShiftScaleRotate(shift_limit=0.0, scale_limit=0.2, rotate_limit=20, p=.4),
                                    albu.ElasticTransform(),
                                    albu.OpticalDistortion(),
                                    albu.OneOf([
                                        albu.CLAHE(clip_limit=2),
                                        albu.Sharpen(),
                                        albu.Emboss(),
                                        albu.RandomBrightnessContrast(),
                                        albu.RandomGamma()
                                    ], p=0.5),
                                    albu.OneOf([
                                        albu.RGBShift(),
                                        albu.HueSaturationValue(),
                                    ], p=0.5),
                                    ]),
            'weak': albu.Compose([albu.HorizontalFlip(),
                                  ]),
            'geometric': albu.Compose([albu.HorizontalFlip(),
                                       albu.VerticalFlip(),
                                       albu.RandomRotate90(),
                                       ]),
            'None': None
            }

    aug_fn = augs[scope]
    crop_fn = {'random': albu.RandomCrop(size, size, always_apply=True),
               'center': albu.CenterCrop(size, size, always_apply=True)}[crop]

    pipeline = albu.Compose([aug_fn, crop_fn], additional_targets={'target': 'image'})


    def process(a, b):
        r = pipeline(image=a, target=b)
        return r['image'], r['target']

    return process


def get_normalize():
    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    def process(a, b):
        image = transform(a).permute(1, 2, 0) - 0.5
        target = transform(b).permute(1, 2, 0) - 0.5
        return image, target

    return process
config/xyscannetp_gopro/config_stage1.yaml
ADDED
@@ -0,0 +1,40 @@
---
experiment_desc: XYScanNet_stage1

train:
  files_a: /scratch/user/hanzhou1996/datasets/deblur/GOPRO_/train/blur/**/*.png
  files_b: /scratch/user/hanzhou1996/datasets/deblur/GOPRO_/train/sharp/**/*.png
  size: &SIZE 252
  crop: random
  preload: &PRELOAD false
  preload_size: &PRELOAD_SIZE 0
  bounds: [0, 1]
  scope: geometric

val:
  files_a: /scratch/user/hanzhou1996/datasets/deblur/GOPRO_/test/blur/**/*.png
  files_b: /scratch/user/hanzhou1996/datasets/deblur/GOPRO_/test/sharp/**/*.png
  size: *SIZE
  scope: None
  crop: random
  preload: *PRELOAD
  preload_size: *PRELOAD_SIZE
  bounds: [0, 1]

model:
  g_name: XYScanNetP
  content_loss: Stripformer_Loss

num_epochs: 4000
train_batches_per_epoch: 2103
val_batches_per_epoch: 1111
batch_size: 16
image_size: [252, 252]

optimizer:
  name: adam
  lr: 0.00022
scheduler:
  name: cosine
  start_epoch: 50
  min_lr: 0.0000001
config/xyscannetp_gopro/config_stage2.yaml
ADDED
@@ -0,0 +1,41 @@
---
experiment_desc: XYScanNet_stage2

train:
  #/mnt/g/RESEARCH/PHD/Motion_Deblurred/datasets/GOPRO_
  files_a: /scratch/user/hanzhou1996/datasets/deblur/GOPRO_/train/blur/**/*.png
  files_b: /scratch/user/hanzhou1996/datasets/deblur/GOPRO_/train/sharp/**/*.png
  size: &SIZE 320
  crop: random
  preload: &PRELOAD false
  preload_size: &PRELOAD_SIZE 0
  bounds: [0, 1]
  scope: geometric

val:
  files_a: /scratch/user/hanzhou1996/datasets/deblur/GOPRO_/test/blur/**/*.png
  files_b: /scratch/user/hanzhou1996/datasets/deblur/GOPRO_/test/sharp/**/*.png
  size: *SIZE
  scope: None
  crop: random
  preload: *PRELOAD
  preload_size: *PRELOAD_SIZE
  bounds: [0, 1]

model:
  g_name: XYScanNetP
  content_loss: Stripformer_Loss

num_epochs: 4000
train_batches_per_epoch: 2103
val_batches_per_epoch: 1111
batch_size: 8
image_size: [320, 320]

optimizer:
  name: adam
  lr: 0.00015
scheduler:
  name: cosine
  start_epoch: 50
  min_lr: 0.0000001
config/xyscannetp_realj/config_stage2.yml
ADDED
@@ -0,0 +1,40 @@
---
experiment_desc: XYScanNet_stage2

train:
  files_a: /scratch/user/hanzhou1996/datasets/deblur/RealBlur_J/train/trainA/*.png
  files_b: /scratch/user/hanzhou1996/datasets/deblur/RealBlur_J/train/trainB/*.png
  size: &SIZE 320
  crop: random
  preload: &PRELOAD false
  preload_size: &PRELOAD_SIZE 0
  bounds: [0, 1]
  scope: geometric

val:
  files_a: /scratch/user/hanzhou1996/datasets/deblur/RealBlur_J/test/testA/*.png
  files_b: /scratch/user/hanzhou1996/datasets/deblur/RealBlur_J/test/testB/*.png
  size: *SIZE
  scope: None
  crop: random
  preload: *PRELOAD
  preload_size: *PRELOAD_SIZE
  bounds: [0, 1]

model:
  g_name: XYScanNetP
  content_loss: Stripformer_Loss

num_epochs: 2000
train_batches_per_epoch: 3758
val_batches_per_epoch: 980
batch_size: 8
image_size: [320, 320]

optimizer:
  name: adam
  lr: 0.0001
scheduler:
  name: cosine
  start_epoch: 50
  min_lr: 0.0000001
config/xyscannetp_realr/config_stage2.yml
ADDED
@@ -0,0 +1,40 @@
---
experiment_desc: XYScanNet_stage2

train:
  files_a: /scratch/user/hanzhou1996/datasets/deblur/RealBlur_R/train/trainA/*.png
  files_b: /scratch/user/hanzhou1996/datasets/deblur/RealBlur_R/train/trainB/*.png
  size: &SIZE 320
  crop: random
  preload: &PRELOAD false
  preload_size: &PRELOAD_SIZE 0
  bounds: [0, 1]
  scope: geometric

val:
  files_a: /scratch/user/hanzhou1996/datasets/deblur/RealBlur_R/test/testA/*.png
  files_b: /scratch/user/hanzhou1996/datasets/deblur/RealBlur_R/test/testB/*.png
  size: *SIZE
  scope: None
  crop: random
  preload: *PRELOAD
  preload_size: *PRELOAD_SIZE
  bounds: [0, 1]

model:
  g_name: XYScanNetP
  content_loss: Stripformer_Loss

num_epochs: 2000
train_batches_per_epoch: 3758
val_batches_per_epoch: 980
batch_size: 8
image_size: [320, 320]

optimizer:
  name: adam
  lr: 0.0001
scheduler:
  name: cosine
  start_epoch: 50
  min_lr: 0.0000001
dataset.py
ADDED
@@ -0,0 +1,140 @@
import os
from copy import deepcopy
from functools import partial
from glob import glob
from hashlib import sha1
from typing import Callable, Iterable, Optional, Tuple

import cv2
import numpy as np
from glog import logger
from joblib import Parallel, cpu_count, delayed
from skimage.io import imread
from torch.utils.data import Dataset
from tqdm import tqdm

import aug


def subsample(data: Iterable, bounds: Tuple[float, float], hash_fn: Callable, n_buckets=100, salt='', verbose=True):
    data = list(data)
    buckets = split_into_buckets(data, n_buckets=n_buckets, salt=salt, hash_fn=hash_fn)

    lower_bound, upper_bound = [x * n_buckets for x in bounds]
    msg = f'Subsampling buckets from {lower_bound} to {upper_bound}, total buckets number is {n_buckets}'
    if salt:
        msg += f'; salt is {salt}'
    if verbose:
        logger.info(msg)
    return np.array([sample for bucket, sample in zip(buckets, data) if lower_bound <= bucket < upper_bound])


def hash_from_paths(x: Tuple[str, str], salt: str = '') -> str:
    path_a, path_b = x
    names = ''.join(map(os.path.basename, (path_a, path_b)))
    return sha1(f'{names}_{salt}'.encode()).hexdigest()


def split_into_buckets(data: Iterable, n_buckets: int, hash_fn: Callable, salt=''):
    hashes = map(partial(hash_fn, salt=salt), data)
    return np.array([int(x, 16) % n_buckets for x in hashes])


def _read_img(x: str):
    img = cv2.imread(x)
    if img is None:
        logger.warning(f'Can not read image {x} with OpenCV, switching to scikit-image')
        img = imread(x)
    return img


class PairedDataset(Dataset):
    def __init__(self,
                 files_a: Tuple[str],
                 files_b: Tuple[str],
                 transform_fn: Callable,
                 normalize_fn: Callable,
                 corrupt_fn: Optional[Callable] = None,
                 preload: bool = True,
                 preload_size: Optional[int] = 0,
                 verbose=True):

        assert len(files_a) == len(files_b)

        self.preload = preload
        self.data_a = files_a
        self.data_b = files_b
        self.verbose = verbose
        self.corrupt_fn = corrupt_fn
        self.transform_fn = transform_fn
        self.normalize_fn = normalize_fn
        logger.info(f'Dataset has been created with {len(self.data_a)} samples')

        if preload:
            preload_fn = partial(self._bulk_preload, preload_size=preload_size)
            if files_a == files_b:
                self.data_a = self.data_b = preload_fn(self.data_a)
            else:
                self.data_a, self.data_b = map(preload_fn, (self.data_a, self.data_b))
            self.preload = True

    def _bulk_preload(self, data: Iterable[str], preload_size: int):
        jobs = [delayed(self._preload)(x, preload_size=preload_size) for x in data]
        jobs = tqdm(jobs, desc='preloading images', disable=not self.verbose)
        return Parallel(n_jobs=cpu_count(), backend='threading')(jobs)

    @staticmethod
    def _preload(x: str, preload_size: int):
        img = _read_img(x)
        if preload_size:
            h, w, *_ = img.shape
            h_scale = preload_size / h
            w_scale = preload_size / w
            scale = max(h_scale, w_scale)
            img = cv2.resize(img, fx=scale, fy=scale, dsize=None)
            assert min(img.shape[:2]) >= preload_size, f'weird img shape: {img.shape}'
        return img

    def _preprocess(self, img, res):
        def transpose(x):
            return np.transpose(x, (2, 0, 1))

        return map(transpose, self.normalize_fn(img, res))

    def __len__(self):
        return len(self.data_a)

    def __getitem__(self, idx):
        a, b = self.data_a[idx], self.data_b[idx]
        if not self.preload:
            a, b = map(_read_img, (a, b))
        a, b = self.transform_fn(a, b)
        if self.corrupt_fn is not None:
            a = self.corrupt_fn(a)
        a, b = self._preprocess(a, b)
        return {'a': a, 'b': b}

    @staticmethod
    def from_config(config):
        config = deepcopy(config)
        files_a, files_b = map(lambda x: sorted(glob(config[x], recursive=True)), ('files_a', 'files_b'))
        transform_fn = aug.get_transforms(size=config['size'], scope=config['scope'], crop=config['crop'])
        normalize_fn = aug.get_normalize()

        hash_fn = hash_from_paths
        # ToDo: add more hash functions
        verbose = config.get('verbose', True)
        data = subsample(data=zip(files_a, files_b),
                         bounds=config.get('bounds', (0, 1)),
                         hash_fn=hash_fn,
                         verbose=verbose)

        files_a, files_b = map(list, zip(*data))

        return PairedDataset(files_a=files_a,
                             files_b=files_b,
                             preload=config['preload'],
                             preload_size=config['preload_size'],
                             normalize_fn=normalize_fn,
                             transform_fn=transform_fn,
                             verbose=verbose)
datasets/datasets.txt
ADDED
@@ -0,0 +1,2 @@
A good way to maintain the datasets in a project is to create a soft link.
In that case, you simply set the dataset path to the current path in the config files.
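As a concrete illustration of the soft-link suggestion above, the snippet below creates a datasets/GOPRO_ link that points at an existing dataset directory, so the glob patterns in the config files can be rewritten relative to the repository. The source path is only an example taken from the configs in this commit; substitute your own location.

# Sketch: link an external dataset directory into ./datasets so config paths
# such as datasets/GOPRO_/train/blur/**/*.png resolve inside the repo.
import os

src = "/scratch/user/hanzhou1996/datasets/deblur/GOPRO_"  # real dataset location (example)
dst = "datasets/GOPRO_"                                   # path the configs can reference

os.makedirs("datasets", exist_ok=True)
if not os.path.islink(dst) and not os.path.exists(dst):
    os.symlink(src, dst)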
evaluate_NIQE.m
ADDED
@@ -0,0 +1,57 @@
%p = genpath("G:\RESEARCH\PHD\Motion_Deblurred\xyscannet\visualization_bidir\comparison\epoch3k\uni");
%gt = genpath("G:\RESEARCH\PHD\Motion_Deblurred\xyscannet\visualization_bidir\comparison\epoch3k\gt");

length_p = size(p,2);
path = {};
temp = [];
for i = 1:length_p
    if p(i) ~= ';'
        temp = [temp p(i)];
    else
        temp = [temp '\'];
        path = [path ; temp];
        temp = [];
    end
end
clear p length_p temp;
length_gt = size(gt,2);
path_gt = {};
temp_gt = [];
for i = 1:length_gt
    if gt(i) ~= ';'
        temp_gt = [temp_gt gt(i)];
    else
        temp_gt = [temp_gt '\'];
        path_gt = [path_gt ; temp_gt];
        temp_gt = [];
    end
end
clear gt length_gt temp_gt;

file_num = size(path,1);
total_niqe = 0;
n = 0;
for i = 1:file_num
    file_path = path{i};
    gt_file_path = path_gt{i};
    img_path_list = dir(strcat(file_path,'*.png'));
    gt_path_list = dir(strcat(gt_file_path,'*.png'));
    img_num = length(img_path_list);
    if img_num > 0
        for j = 1:img_num
            image_name = img_path_list(j).name;
            gt_name = gt_path_list(j).name;
            image = imread(strcat(file_path,image_name));
            gt = imread(strcat(gt_file_path,gt_name));
            size(image);
            size(gt);
            cur_niqe = niqe(image);
            fprintf('%d', cur_niqe);
            total_niqe = total_niqe + cur_niqe;
            n = n + 1
        end
    end
end
niqe_score = total_niqe / n
close all;clear all;
evaluate_RealBlur_J.py
ADDED
@@ -0,0 +1,117 @@
import os
from skimage import io
import cv2
import numpy as np
from skimage.metrics import structural_similarity
import concurrent.futures

def image_align(deblurred, gt):
    # this function is based on kohler evaluation code
    z = deblurred
    c = np.ones_like(z)
    x = gt

    zs = (np.sum(x * z) / np.sum(z * z)) * z  # simple intensity matching

    warp_mode = cv2.MOTION_HOMOGRAPHY
    warp_matrix = np.eye(3, 3, dtype=np.float32)

    # Specify the number of iterations.
    number_of_iterations = 100

    termination_eps = 0

    criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
                number_of_iterations, termination_eps)

    # Run the ECC algorithm. The results are stored in warp_matrix.
    (cc, warp_matrix) = cv2.findTransformECC(cv2.cvtColor(x, cv2.COLOR_RGB2GRAY), cv2.cvtColor(zs, cv2.COLOR_RGB2GRAY),
                                             warp_matrix, warp_mode, criteria, inputMask=None, gaussFiltSize=5)

    target_shape = x.shape
    shift = warp_matrix

    zr = cv2.warpPerspective(
        zs,
        warp_matrix,
        (target_shape[1], target_shape[0]),
        flags=cv2.INTER_CUBIC + cv2.WARP_INVERSE_MAP,
        borderMode=cv2.BORDER_REFLECT)

    cr = cv2.warpPerspective(
        np.ones_like(zs, dtype='float32'),
        warp_matrix,
        (target_shape[1], target_shape[0]),
        flags=cv2.INTER_NEAREST + cv2.WARP_INVERSE_MAP,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=0)

    zr = zr * cr
    xr = x * cr

    return zr, xr, cr, shift

def compute_psnr(image_true, image_test, image_mask, data_range=None):
    # this function is based on skimage.metrics.peak_signal_noise_ratio
    err = np.sum((image_true - image_test) ** 2, dtype=np.float64) / np.sum(image_mask)
    return 10 * np.log10((data_range ** 2) / err)


def compute_ssim(tar_img, prd_img, cr1):
    ssim_pre, ssim_map = structural_similarity(tar_img, prd_img, channel_axis=2, gaussian_weights=True,
                                               use_sample_covariance=False, data_range=1.0, full=True)
    ssim_map = ssim_map * cr1
    r = int(3.5 * 1.5 + 0.5)  # radius as in ndimage
    win_size = 2 * r + 1
    pad = (win_size - 1) // 2
    ssim = ssim_map[pad:-pad, pad:-pad, :]
    crop_cr1 = cr1[pad:-pad, pad:-pad, :]
    ssim = ssim.sum(axis=0).sum(axis=0) / crop_cr1.sum(axis=0).sum(axis=0)
    ssim = np.mean(ssim)
    return ssim

total_psnr = 0.
total_ssim = 0.
count = 0
#img_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/xyscannet/ablation/v33/run1/images_realj'
#img_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/xyscannet/sota/algnet/images/RealBlur_J'
#img_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/xyscannet/sota/deeprft/FMIMOUNetPLUS_RealBlur/RealBlur_J--'
#img_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/xyscannet/sota/deeprft/images_author/RealBlur_J'
#img_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/xyscannet/sota/mprnet/images/RealBlur_J'
#img_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/xyscannet/sota/stripformer/images/RealBlur_J'
img_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/xyscannet/sota/xyscannetp/images/realj_final_stage3'

gt_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/datasets/Realblur_J/test/testB'

print(img_path)
for file in os.listdir(img_path):
    #for img_name in os.listdir(img_path + '/' + file):
    img_name = file
    count += 1
    number = img_name.split('_')[1]
    #number = img_name.split('-')[1]
    #gt_name = 'gt_' + number
    img_dir = img_path + '/' + file
    s = file.split('_')
    #s = file.split('-')
    #gt_file = s[0] + '_' + 'gt_' + number
    gt_file = '_'.join([s[0], 'gt', s[-1]])
    gt_dir = gt_path + '/' + gt_file
    print(gt_file)
    print(img_dir)
    with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor:
        tar_img = io.imread(gt_dir)
        prd_img = io.imread(img_dir)
        tar_img = tar_img.astype(np.float32) / 255.0
        prd_img = prd_img.astype(np.float32) / 255.0
        prd_img, tar_img, cr1, shift = image_align(prd_img, tar_img)
        PSNR = compute_psnr(tar_img, prd_img, cr1, data_range=1)
        SSIM = compute_ssim(tar_img, prd_img, cr1)
        total_psnr += PSNR
        total_ssim += SSIM
        print(count, PSNR)

print('PSNR:', total_psnr / count)
print('SSIM:', total_ssim / count)
print(img_path)
evaluate_RealBlur_R.py
ADDED
@@ -0,0 +1,110 @@
import os
from skimage import io
import cv2
import numpy as np
from skimage.metrics import structural_similarity
import concurrent.futures

def image_align(deblurred, gt):
    # this function is based on kohler evaluation code
    z = deblurred
    c = np.ones_like(z)
    x = gt

    zs = (np.sum(x * z) / np.sum(z * z)) * z  # simple intensity matching

    warp_mode = cv2.MOTION_HOMOGRAPHY
    warp_matrix = np.eye(3, 3, dtype=np.float32)

    # Specify the number of iterations.
    number_of_iterations = 100

    termination_eps = 0

    criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
                number_of_iterations, termination_eps)

    # Run the ECC algorithm. The results are stored in warp_matrix.
    (cc, warp_matrix) = cv2.findTransformECC(cv2.cvtColor(x, cv2.COLOR_RGB2GRAY), cv2.cvtColor(zs, cv2.COLOR_RGB2GRAY),
                                             warp_matrix, warp_mode, criteria, inputMask=None, gaussFiltSize=5)

    target_shape = x.shape
    shift = warp_matrix

    zr = cv2.warpPerspective(
        zs,
        warp_matrix,
        (target_shape[1], target_shape[0]),
        flags=cv2.INTER_CUBIC + cv2.WARP_INVERSE_MAP,
        borderMode=cv2.BORDER_REFLECT)

    cr = cv2.warpPerspective(
        np.ones_like(zs, dtype='float32'),
        warp_matrix,
        (target_shape[1], target_shape[0]),
        flags=cv2.INTER_NEAREST + cv2.WARP_INVERSE_MAP,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=0)

    zr = zr * cr
    xr = x * cr

    return zr, xr, cr, shift

def compute_psnr(image_true, image_test, image_mask, data_range=None):
    # this function is based on skimage.metrics.peak_signal_noise_ratio
    err = np.sum((image_true - image_test) ** 2, dtype=np.float64) / np.sum(image_mask)
    return 10 * np.log10((data_range ** 2) / err)


def compute_ssim(tar_img, prd_img, cr1):
    ssim_pre, ssim_map = structural_similarity(tar_img, prd_img, channel_axis=2, gaussian_weights=True,
                                               use_sample_covariance=False, data_range=1.0, full=True)
    ssim_map = ssim_map * cr1
    r = int(3.5 * 1.5 + 0.5)  # radius as in ndimage
    win_size = 2 * r + 1
    pad = (win_size - 1) // 2
    ssim = ssim_map[pad:-pad, pad:-pad, :]
    crop_cr1 = cr1[pad:-pad, pad:-pad, :]
    ssim = ssim.sum(axis=0).sum(axis=0) / crop_cr1.sum(axis=0).sum(axis=0)
    ssim = np.mean(ssim)
    return ssim

total_psnr = 0.
total_ssim = 0.
count = 0
#gt_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/datasets/Realblur_R/test/testB'
img_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/xyscannet/sota/deeprft/FMIMOUNetPLUS_RealBlur/RealBlur_R__'
#img_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/xyscannet/sota/stripformer/images/RealBlur_R'
#img_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/xyscannet/sota/deeprft/images/RealBlur_R'


gt_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/datasets/Realblur_R/test/testB'
print(img_path)
for file in os.listdir(img_path):
    #for img_name in os.listdir(img_path + '/' + file):
    img_name = file
    count += 1
    number = img_name.split('_')[1]
    #gt_name = 'gt_' + number
    img_dir = img_path + '/' + file
    s = file.split('_')
    gt_file = '_'.join([s[0], 'gt', s[-1]])
    gt_dir = gt_path + '/' + gt_file
    print(img_dir)
    with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor:
        tar_img = io.imread(gt_dir)
        prd_img = io.imread(img_dir)
        tar_img = tar_img.astype(np.float32) / 255.0
        prd_img = prd_img.astype(np.float32) / 255.0
        prd_img, tar_img, cr1, shift = image_align(prd_img, tar_img)
        PSNR = compute_psnr(tar_img, prd_img, cr1, data_range=1)
        SSIM = compute_ssim(tar_img, prd_img, cr1)
        total_psnr += PSNR
        total_ssim += SSIM
        print(count, PSNR)

print('PSNR:', total_psnr / count)
print('SSIM:', total_ssim / count)
print(img_path)
evaluation_GoPro.m
ADDED
@@ -0,0 +1,60 @@
p = genpath('.\out\Stripformer_GoPro_results');% GoPro Deblur Results
gt = genpath('.\datasets\GoPro\test\sharp');% GoPro GT Results

length_p = size(p,2);
path = {};
temp = [];
for i = 1:length_p
    if p(i) ~= ';'
        temp = [temp p(i)];
    else
        temp = [temp '\'];
        path = [path ; temp];
        temp = [];
    end
end
clear p length_p temp;
length_gt = size(gt,2);
path_gt = {};
temp_gt = [];
for i = 1:length_gt
    if gt(i) ~= ';'
        temp_gt = [temp_gt gt(i)];
    else
        temp_gt = [temp_gt '\'];
        path_gt = [path_gt ; temp_gt];
        temp_gt = [];
    end
end
clear gt length_gt temp_gt;

file_num = size(path,1);
total_psnr = 0;
n = 0;
total_ssim = 0;
for i = 1:file_num
    file_path = path{i};
    gt_file_path = path_gt{i};
    img_path_list = dir(strcat(file_path,'*.png'));
    gt_path_list = dir(strcat(gt_file_path,'*.png'));
    img_num = length(img_path_list);
    if img_num > 0
        for j = 1:img_num
            image_name = img_path_list(j).name;
            gt_name = gt_path_list(j).name;
            image = imread(strcat(file_path,image_name));
            gt = imread(strcat(gt_file_path,gt_name));
            size(image);
            size(gt);
            peaksnr = psnr(image,gt);
            ssimval = ssim(image,gt);
            total_psnr = total_psnr + peaksnr;
            total_ssim = total_ssim + ssimval;
            n = n + 1
        end
    end
end
psnr = total_psnr / n
ssim = total_ssim / n
close all;clear all;
evaluation_HIDE.m
ADDED
@@ -0,0 +1,60 @@
p = genpath('.\out\Stripformer_HIDE_results');% HIDE Deblur Results
gt = genpath('.\datasets\HIDE\sharp');% HIDE GT Results

length_p = size(p,2);
path = {};
temp = [];
for i = 1:length_p
    if p(i) ~= ';'
        temp = [temp p(i)];
    else
        temp = [temp '\'];
        path = [path ; temp];
        temp = [];
    end
end
clear p length_p temp;
length_gt = size(gt,2);
path_gt = {};
temp_gt = [];
for i = 1:length_gt
    if gt(i) ~= ';'
        temp_gt = [temp_gt gt(i)];
    else
        temp_gt = [temp_gt '\'];
        path_gt = [path_gt ; temp_gt];
        temp_gt = [];
    end
end
clear gt length_gt temp_gt;

file_num = size(path,1);
total_psnr = 0;
n = 0;
total_ssim = 0;
for i = 1:file_num
    file_path = path{i};
    gt_file_path = path_gt{i};
    img_path_list = dir(strcat(file_path,'*.png'));
    gt_path_list = dir(strcat(gt_file_path,'*.png'));
    img_num = length(img_path_list);
    if img_num > 0
        for j = 1:img_num
            image_name = img_path_list(j).name;
            gt_name = gt_path_list(j).name;
            image = imread(strcat(file_path,image_name));
            gt = imread(strcat(gt_file_path,gt_name));
            size(image);
            size(gt);
            peaksnr = psnr(image,gt);
            ssimval = ssim(image,gt);
            total_psnr = total_psnr + peaksnr;
            total_ssim = total_ssim + ssimval;
            n = n + 1
        end
    end
end
psnr = total_psnr / n
ssim = total_ssim / n
close all;clear all;
examples/blur1.png
ADDED (Git LFS)

examples/blur2.png
ADDED (Git LFS)

examples/blur3.png
ADDED (Git LFS)

examples/blur4.png
ADDED (Git LFS)

examples/blur5.png
ADDED (Git LFS)
license
ADDED
@@ -0,0 +1,37 @@

Copyright (c) 2025 Hanzhou Liu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software with non-commercial usage, including non-commercial usage
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------- LICENSE FOR DeblurGANv2 --------------------------------
BSD License

For DeblurGANv2 software
Copyright (c) 2019, Orest Kupyn, Tetiana Martyniuk, Junru Wu and Zhangyang Wang
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.
metric_counter.py
ADDED
@@ -0,0 +1,55 @@
import logging
from collections import defaultdict

import numpy as np
from tensorboardX import SummaryWriter

WINDOW_SIZE = 100


class MetricCounter:
    def __init__(self, exp_name):
        self.writer = SummaryWriter(exp_name)
        logging.basicConfig(filename='{}.log'.format(exp_name), level=logging.DEBUG)
        self.metrics = defaultdict(list)
        self.images = defaultdict(list)
        self.best_metric = 0

    def add_image(self, x: np.ndarray, tag: str):
        self.images[tag].append(x)

    def clear(self):
        self.metrics = defaultdict(list)
        self.images = defaultdict(list)

    def add_losses(self, l_G):
        for name, value in zip(('G_loss', None), (l_G, None)):
            self.metrics[name].append(value)

    def add_metrics(self, psnr, ssim):
        for name, value in zip(('PSNR', 'SSIM'),
                               (psnr, ssim)):
            self.metrics[name].append(value)

    def loss_message(self):
        metrics = ((k, np.mean(self.metrics[k][-WINDOW_SIZE:])) for k in ('G_loss', 'PSNR', 'SSIM'))
        return '; '.join(map(lambda x: f'{x[0]}={x[1]:.4f}', metrics))

    def write_to_tensorboard(self, epoch_num, validation=False):
        scalar_prefix = 'Validation' if validation else 'Train'
        for tag in ('G_loss', 'SSIM', 'PSNR'):
            self.writer.add_scalar(f'{scalar_prefix}_{tag}', np.mean(self.metrics[tag]), global_step=epoch_num)
        for tag in self.images:
            imgs = self.images[tag]
            if imgs:
                imgs = np.array(imgs)
                self.writer.add_images(tag, imgs[:, :, :, ::-1].astype('float32') / 255, dataformats='NHWC',
                                       global_step=epoch_num)
                self.images[tag] = []

    def update_best_model(self):
        cur_metric = np.mean(self.metrics['PSNR'])
        if self.best_metric < cur_metric:
            self.best_metric = cur_metric
            return True
        return False
models/XYScanNet.py
ADDED
|
@@ -0,0 +1,737 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numbers
|
| 2 |
+
import math
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
|
| 10 |
+
from einops import rearrange, repeat
|
| 11 |
+
|
| 12 |
+
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
| 16 |
+
except ImportError:
|
| 17 |
+
causal_conv1d_fn, causal_conv1d_update = None, None
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
| 21 |
+
except ImportError:
|
| 22 |
+
selective_state_update = None
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
|
| 26 |
+
except ImportError:
|
| 27 |
+
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def to_3d(x):
|
| 31 |
+
return rearrange(x, 'b c h w -> b (h w) c')
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def to_4d(x, h, w):
|
| 35 |
+
return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class BiasFree_LayerNorm(nn.Module):
|
| 39 |
+
def __init__(self, normalized_shape):
|
| 40 |
+
super(BiasFree_LayerNorm, self).__init__()
|
| 41 |
+
if isinstance(normalized_shape, numbers.Integral):
|
| 42 |
+
normalized_shape = (normalized_shape,)
|
| 43 |
+
normalized_shape = torch.Size(normalized_shape)
|
| 44 |
+
|
| 45 |
+
assert len(normalized_shape) == 1
|
| 46 |
+
|
| 47 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
| 48 |
+
self.normalized_shape = normalized_shape
|
| 49 |
+
|
| 50 |
+
def forward(self, x):
|
| 51 |
+
sigma = x.var(-1, keepdim=True, unbiased=False)
|
| 52 |
+
return x / torch.sqrt(sigma + 1e-5) * self.weight
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class WithBias_LayerNorm(nn.Module):
|
| 56 |
+
def __init__(self, normalized_shape):
|
| 57 |
+
super(WithBias_LayerNorm, self).__init__()
|
| 58 |
+
if isinstance(normalized_shape, numbers.Integral):
|
| 59 |
+
normalized_shape = (normalized_shape,)
|
| 60 |
+
normalized_shape = torch.Size(normalized_shape)
|
| 61 |
+
|
| 62 |
+
assert len(normalized_shape) == 1
|
| 63 |
+
|
| 64 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
| 65 |
+
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
| 66 |
+
self.normalized_shape = normalized_shape
|
| 67 |
+
|
| 68 |
+
def forward(self, x):
|
| 69 |
+
mu = x.mean(-1, keepdim=True)
|
| 70 |
+
sigma = x.var(-1, keepdim=True, unbiased=False)
|
| 71 |
+
return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class LayerNorm(nn.Module):
|
| 75 |
+
def __init__(self, dim, LayerNorm_type):
|
| 76 |
+
super(LayerNorm, self).__init__()
|
| 77 |
+
if LayerNorm_type == 'BiasFree':
|
| 78 |
+
self.body = BiasFree_LayerNorm(dim)
|
| 79 |
+
else:
|
| 80 |
+
self.body = WithBias_LayerNorm(dim)
|
| 81 |
+
|
| 82 |
+
def forward(self, x):
|
| 83 |
+
h, w = x.shape[-2:]
|
| 84 |
+
return to_4d(self.body(to_3d(x)), h, w)
|
| 85 |
+
|
| 86 |
+
##########################################################################
|
| 87 |
+
def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1):
|
| 88 |
+
return nn.Conv2d(
|
| 89 |
+
in_channels, out_channels, kernel_size,
|
| 90 |
+
padding=(kernel_size//2), bias=bias, stride = stride)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
"""
|
| 94 |
+
Borrow from "https://github.com/state-spaces/mamba.git"
|
| 95 |
+
@article{mamba,
|
| 96 |
+
title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
|
| 97 |
+
author={Gu, Albert and Dao, Tri},
|
| 98 |
+
journal={arXiv preprint arXiv:2312.00752},
|
| 99 |
+
year={2023}
|
| 100 |
+
}
|
| 101 |
+
"""
|
| 102 |
+
class Mamba(nn.Module):
|
| 103 |
+
def __init__(
|
| 104 |
+
self,
|
| 105 |
+
d_model,
|
| 106 |
+
d_state=16,
|
| 107 |
+
d_conv=4,
|
| 108 |
+
expand=2,
|
| 109 |
+
dt_rank="auto",
|
| 110 |
+
dt_min=0.001,
|
| 111 |
+
dt_max=0.1,
|
| 112 |
+
dt_init="random",
|
| 113 |
+
dt_scale=1.0,
|
| 114 |
+
dt_init_floor=1e-4,
|
| 115 |
+
conv_bias=True,
|
| 116 |
+
bias=False,
|
| 117 |
+
use_fast_path=True, # Fused kernel options
|
| 118 |
+
layer_idx=None,
|
| 119 |
+
device=None,
|
| 120 |
+
dtype=None,
|
| 121 |
+
):
|
| 122 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 123 |
+
super().__init__()
|
| 124 |
+
self.d_model = d_model
|
| 125 |
+
self.d_state = d_state
|
| 126 |
+
self.d_conv = d_conv
|
| 127 |
+
self.expand = expand
|
| 128 |
+
self.d_inner = int(self.expand * self.d_model)
|
| 129 |
+
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
|
| 130 |
+
self.use_fast_path = use_fast_path
|
| 131 |
+
self.layer_idx = layer_idx
|
| 132 |
+
|
| 133 |
+
self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
|
| 134 |
+
|
| 135 |
+
self.conv1d = nn.Conv1d(
|
| 136 |
+
in_channels=self.d_inner,
|
| 137 |
+
out_channels=self.d_inner,
|
| 138 |
+
bias=conv_bias,
|
| 139 |
+
kernel_size=d_conv,
|
| 140 |
+
groups=self.d_inner,
|
| 141 |
+
padding=d_conv - 1,
|
| 142 |
+
**factory_kwargs,
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
self.activation = "silu"
|
| 146 |
+
self.act = nn.SiLU()
|
| 147 |
+
|
| 148 |
+
self.x_proj = nn.Linear(
|
| 149 |
+
self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
|
| 150 |
+
)
|
| 151 |
+
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
|
| 152 |
+
|
| 153 |
+
# Initialize special dt projection to preserve variance at initialization
|
| 154 |
+
dt_init_std = self.dt_rank**-0.5 * dt_scale
|
| 155 |
+
if dt_init == "constant":
|
| 156 |
+
nn.init.constant_(self.dt_proj.weight, dt_init_std)
|
| 157 |
+
elif dt_init == "random":
|
| 158 |
+
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
|
| 159 |
+
else:
|
| 160 |
+
raise NotImplementedError
|
| 161 |
+
|
| 162 |
+
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
|
| 163 |
+
dt = torch.exp(
|
| 164 |
+
torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
|
| 165 |
+
+ math.log(dt_min)
|
| 166 |
+
).clamp(min=dt_init_floor)
|
| 167 |
+
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
| 168 |
+
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
| 169 |
+
with torch.no_grad():
|
| 170 |
+
self.dt_proj.bias.copy_(inv_dt)
|
| 171 |
+
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
|
| 172 |
+
self.dt_proj.bias._no_reinit = True
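# Sanity note on the initialization above: inv_dt = dt + log(-expm1(-dt))
# equals log(exp(dt) - 1), i.e. the inverse of softplus, so
# F.softplus(self.dt_proj.bias) reproduces the sampled dt values, which by
# construction lie in [dt_min, dt_max] (up to the dt_init_floor clamp).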
|
| 173 |
+
|
| 174 |
+
# S4D real initialization
|
| 175 |
+
A = repeat(
|
| 176 |
+
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
|
| 177 |
+
"n -> d n",
|
| 178 |
+
d=self.d_inner,
|
| 179 |
+
).contiguous()
|
| 180 |
+
A_log = torch.log(A) # Keep A_log in fp32
|
| 181 |
+
self.A_log = nn.Parameter(A_log)
|
| 182 |
+
self.A_log._no_weight_decay = True
|
| 183 |
+
|
| 184 |
+
# D "skip" parameter
|
| 185 |
+
self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
|
| 186 |
+
self.D._no_weight_decay = True
|
| 187 |
+
|
| 188 |
+
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
|
| 189 |
+
|
| 190 |
+
def forward(self, hidden_states, inference_params=None):
|
| 191 |
+
"""
|
| 192 |
+
hidden_states: (B, L, D)
|
| 193 |
+
Returns: same shape as hidden_states
|
| 194 |
+
"""
|
| 195 |
+
batch, seqlen, dim = hidden_states.shape
|
| 196 |
+
|
| 197 |
+
conv_state, ssm_state = None, None
|
| 198 |
+
if inference_params is not None:
|
| 199 |
+
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
|
| 200 |
+
if inference_params.seqlen_offset > 0:
|
| 201 |
+
# The states are updated inplace
|
| 202 |
+
out, _, _ = self.step(hidden_states, conv_state, ssm_state)
|
| 203 |
+
return out
|
| 204 |
+
|
| 205 |
+
# We do matmul and transpose BLH -> HBL at the same time
|
| 206 |
+
xz = rearrange(
|
| 207 |
+
self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
|
| 208 |
+
"d (b l) -> b d l",
|
| 209 |
+
l=seqlen,
|
| 210 |
+
)
|
| 211 |
+
if self.in_proj.bias is not None:
|
| 212 |
+
xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
|
| 213 |
+
|
| 214 |
+
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
| 215 |
+
# In the backward pass we write dx and dz next to each other to avoid torch.cat
|
| 216 |
+
if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None: # Doesn't support outputting the states
|
| 217 |
+
out = mamba_inner_fn(
|
| 218 |
+
xz,
|
| 219 |
+
self.conv1d.weight,
|
| 220 |
+
self.conv1d.bias,
|
| 221 |
+
self.x_proj.weight,
|
| 222 |
+
self.dt_proj.weight,
|
| 223 |
+
self.out_proj.weight,
|
| 224 |
+
self.out_proj.bias,
|
| 225 |
+
A,
|
| 226 |
+
None, # input-dependent B
|
| 227 |
+
None, # input-dependent C
|
| 228 |
+
self.D.float(),
|
| 229 |
+
delta_bias=self.dt_proj.bias.float(),
|
| 230 |
+
delta_softplus=True,
|
| 231 |
+
)
|
| 232 |
+
else:
|
| 233 |
+
x, z = xz.chunk(2, dim=1)
|
| 234 |
+
# Compute short convolution
|
| 235 |
+
if conv_state is not None:
|
| 236 |
+
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
|
| 237 |
+
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
|
| 238 |
+
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
|
| 239 |
+
if causal_conv1d_fn is None:
|
| 240 |
+
x = self.act(self.conv1d(x)[..., :seqlen])
|
| 241 |
+
else:
|
| 242 |
+
assert self.activation in ["silu", "swish"]
|
| 243 |
+
x = causal_conv1d_fn(
|
| 244 |
+
x=x,
|
| 245 |
+
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
| 246 |
+
bias=self.conv1d.bias,
|
| 247 |
+
activation=self.activation,
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
# We're careful here about the layout, to avoid extra transposes.
|
| 251 |
+
# We want dt to have d as the slowest moving dimension
|
| 252 |
+
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
| 253 |
+
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
|
| 254 |
+
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
| 255 |
+
dt = self.dt_proj.weight @ dt.t()
|
| 256 |
+
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
|
| 257 |
+
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
| 258 |
+
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
| 259 |
+
assert self.activation in ["silu", "swish"]
|
| 260 |
+
y = selective_scan_fn(
|
| 261 |
+
x,
|
| 262 |
+
dt,
|
| 263 |
+
A,
|
| 264 |
+
B,
|
| 265 |
+
C,
|
| 266 |
+
self.D.float(),
|
| 267 |
+
z=z,
|
| 268 |
+
delta_bias=self.dt_proj.bias.float(),
|
| 269 |
+
delta_softplus=True,
|
| 270 |
+
return_last_state=ssm_state is not None,
|
| 271 |
+
)
|
| 272 |
+
if ssm_state is not None:
|
| 273 |
+
y, last_state = y
|
| 274 |
+
ssm_state.copy_(last_state)
|
| 275 |
+
y = rearrange(y, "b d l -> b l d")
|
| 276 |
+
out = self.out_proj(y)
|
| 277 |
+
return out
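# A minimal usage sketch for this block (hypothetical sizes; assumes a CUDA
# build of mamba_ssm, since selective_scan_fn runs on the GPU):
#   >>> m = Mamba(d_model=36).cuda()
#   >>> y = m(torch.randn(2, 1024, 36, device="cuda"))  # (B, L, D) -> (B, L, D)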
|
| 278 |
+
|
| 279 |
+
def step(self, hidden_states, conv_state, ssm_state):
|
| 280 |
+
dtype = hidden_states.dtype
|
| 281 |
+
assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
|
| 282 |
+
xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
|
| 283 |
+
x, z = xz.chunk(2, dim=-1) # (B D)
|
| 284 |
+
|
| 285 |
+
# Conv step
|
| 286 |
+
if causal_conv1d_update is None:
|
| 287 |
+
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
|
| 288 |
+
conv_state[:, :, -1] = x
|
| 289 |
+
x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
|
| 290 |
+
if self.conv1d.bias is not None:
|
| 291 |
+
x = x + self.conv1d.bias
|
| 292 |
+
x = self.act(x).to(dtype=dtype)
|
| 293 |
+
else:
|
| 294 |
+
x = causal_conv1d_update(
|
| 295 |
+
x,
|
| 296 |
+
conv_state,
|
| 297 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
| 298 |
+
self.conv1d.bias,
|
| 299 |
+
self.activation,
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
x_db = self.x_proj(x) # (B dt_rank+2*d_state)
|
| 303 |
+
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
| 304 |
+
# Don't add dt_bias here
|
| 305 |
+
dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
|
| 306 |
+
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
| 307 |
+
|
| 308 |
+
# SSM step
|
| 309 |
+
if selective_state_update is None:
|
| 310 |
+
# Discretize A and B
|
| 311 |
+
dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
|
| 312 |
+
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
|
| 313 |
+
dB = torch.einsum("bd,bn->bdn", dt, B)
|
| 314 |
+
ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
|
| 315 |
+
y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
|
| 316 |
+
y = y + self.D.to(dtype) * x
|
| 317 |
+
y = y * self.act(z) # (B D)
|
| 318 |
+
else:
|
| 319 |
+
y = selective_state_update(
|
| 320 |
+
ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
out = self.out_proj(y)
|
| 324 |
+
return out.unsqueeze(1), conv_state, ssm_state
|
| 325 |
+
|
| 326 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
| 327 |
+
device = self.out_proj.weight.device
|
| 328 |
+
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
|
| 329 |
+
conv_state = torch.zeros(
|
| 330 |
+
batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
|
| 331 |
+
)
|
| 332 |
+
ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
|
| 333 |
+
# ssm_dtype = torch.float32
|
| 334 |
+
ssm_state = torch.zeros(
|
| 335 |
+
batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
|
| 336 |
+
)
|
| 337 |
+
return conv_state, ssm_state
|
| 338 |
+
|
| 339 |
+
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
|
| 340 |
+
assert self.layer_idx is not None
|
| 341 |
+
if self.layer_idx not in inference_params.key_value_memory_dict:
|
| 342 |
+
batch_shape = (batch_size,)
|
| 343 |
+
conv_state = torch.zeros(
|
| 344 |
+
batch_size,
|
| 345 |
+
self.d_model * self.expand,
|
| 346 |
+
self.d_conv,
|
| 347 |
+
device=self.conv1d.weight.device,
|
| 348 |
+
dtype=self.conv1d.weight.dtype,
|
| 349 |
+
)
|
| 350 |
+
ssm_state = torch.zeros(
|
| 351 |
+
batch_size,
|
| 352 |
+
self.d_model * self.expand,
|
| 353 |
+
self.d_state,
|
| 354 |
+
device=self.dt_proj.weight.device,
|
| 355 |
+
dtype=self.dt_proj.weight.dtype,
|
| 356 |
+
# dtype=torch.float32,
|
| 357 |
+
)
|
| 358 |
+
inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
|
| 359 |
+
else:
|
| 360 |
+
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
|
| 361 |
+
# TODO: What if batch size changes between generation, and we reuse the same states?
|
| 362 |
+
if initialize_states:
|
| 363 |
+
conv_state.zero_()
|
| 364 |
+
ssm_state.zero_()
|
| 365 |
+
return conv_state, ssm_state
|
| 366 |
+
|
| 367 |
+
##########################################################################
|
| 368 |
+
## Feed-forward Network
|
| 369 |
+
class FFN(nn.Module):
|
| 370 |
+
def __init__(self, dim, ffn_expansion_factor, bias):
|
| 371 |
+
super(FFN, self).__init__()
|
| 372 |
+
|
| 373 |
+
hidden_features = int(dim*ffn_expansion_factor)
|
| 374 |
+
|
| 375 |
+
self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
|
| 376 |
+
|
| 377 |
+
self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias, dilation=1)
|
| 378 |
+
|
| 379 |
+
self.win_size = 8
|
| 380 |
+
|
| 381 |
+
self.modulator = nn.Parameter(torch.ones(self.win_size, self.win_size, dim*2)) # modulator
|
| 382 |
+
|
| 383 |
+
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
|
| 384 |
+
|
| 385 |
+
def forward(self, x):
|
| 386 |
+
b, c, h, w = x.shape
|
| 387 |
+
h1, w1 = h//self.win_size, w//self.win_size
|
| 388 |
+
x = self.project_in(x)
|
| 389 |
+
x = self.dwconv(x)
|
| 390 |
+
x_win = rearrange(x, 'b c (wsh h1) (wsw w1) -> b h1 w1 wsh wsw c', wsh=self.win_size, wsw=self.win_size)
|
| 391 |
+
x_win = x_win * self.modulator
|
| 392 |
+
x = rearrange(x_win, 'b h1 w1 wsh wsw c -> b c (wsh h1) (wsw w1)', wsh=self.win_size, wsw=self.win_size, h1=h1, w1=w1)
|
| 393 |
+
x1, x2 = x.chunk(2, dim=1)
|
| 394 |
+
x = x1 * x2
|
| 395 |
+
x = self.project_out(x)
|
| 396 |
+
return x
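# Note: the learnable window modulator above assumes H and W are divisible by
# self.win_size (8), and its last dim (dim*2) matches the hidden_features*2
# channels only when ffn_expansion_factor == 1. In this file, Strip_VSSB
# instantiates the GDFN variant below rather than this FFN.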
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
##########################################################################
|
| 400 |
+
## Gated Depth-wise Feed-forward Network (GDFN)
|
| 401 |
+
class GDFN(nn.Module):
|
| 402 |
+
def __init__(self, dim, ffn_expansion_factor, bias):
|
| 403 |
+
super(GDFN, self).__init__()
|
| 404 |
+
|
| 405 |
+
hidden_features = int(dim*ffn_expansion_factor)
|
| 406 |
+
|
| 407 |
+
self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
|
| 408 |
+
|
| 409 |
+
self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias, dilation=1)
|
| 410 |
+
|
| 411 |
+
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
|
| 412 |
+
|
| 413 |
+
def forward(self, x):
|
| 414 |
+
x = self.project_in(x)
|
| 415 |
+
x = self.dwconv(x)
|
| 416 |
+
x1, x2 = x.chunk(2, dim=1)
|
| 417 |
+
x = F.silu(x1) * x2
|
| 418 |
+
x = self.project_out(x)
|
| 419 |
+
return x
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
##########################################################################
|
| 423 |
+
## Overlapped image patch embedding with 3x3 Conv
|
| 424 |
+
class OverlapPatchEmbed(nn.Module):
|
| 425 |
+
def __init__(self, in_c=3, embed_dim=48, bias=False):
|
| 426 |
+
super(OverlapPatchEmbed, self).__init__()
|
| 427 |
+
|
| 428 |
+
self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)
|
| 429 |
+
|
| 430 |
+
def forward(self, x):
|
| 431 |
+
x = self.proj(x)
|
| 432 |
+
|
| 433 |
+
return x
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
##########################################################################
|
| 437 |
+
## Resizing modules
|
| 438 |
+
class Downsample(nn.Module):
|
| 439 |
+
def __init__(self, n_feat):
|
| 440 |
+
super(Downsample, self).__init__()
|
| 441 |
+
|
| 442 |
+
self.body = nn.Sequential(nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
|
| 443 |
+
nn.Conv2d(n_feat, n_feat * 2, 3, stride=1, padding=1, bias=False))
|
| 444 |
+
|
| 445 |
+
def forward(self, x):
|
| 446 |
+
return self.body(x)
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
class Upsample(nn.Module):
|
| 450 |
+
def __init__(self, n_feat):
|
| 451 |
+
super(Upsample, self).__init__()
|
| 452 |
+
|
| 453 |
+
self.body = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
|
| 454 |
+
nn.Conv2d(n_feat, n_feat // 2, 3, stride=1, padding=1, bias=False))
|
| 455 |
+
|
| 456 |
+
def forward(self, x):
|
| 457 |
+
return self.body(x)
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
"""
|
| 461 |
+
Borrow from "https://github.com/pp00704831/Stripformer-ECCV-2022-.git"
|
| 462 |
+
@inproceedings{Tsai2022Stripformer,
|
| 463 |
+
author = {Fu-Jen Tsai and Yan-Tsung Peng and Yen-Yu Lin and Chung-Chi Tsai and Chia-Wen Lin},
|
| 464 |
+
title = {Stripformer: Strip Transformer for Fast Image Deblurring},
|
| 465 |
+
booktitle = {ECCV},
|
| 466 |
+
year = {2022}
|
| 467 |
+
}
|
| 468 |
+
"""
|
| 469 |
+
class Intra_VSSM(nn.Module):
|
| 470 |
+
def __init__(self, dim, vssm_expansion_factor, bias): # gated = True
|
| 471 |
+
super(Intra_VSSM, self).__init__()
|
| 472 |
+
hidden = int(dim*vssm_expansion_factor)
|
| 473 |
+
|
| 474 |
+
self.proj_in = nn.Conv2d(dim, hidden*2, kernel_size=1, bias=bias)
|
| 475 |
+
self.dwconv = nn.Conv2d(hidden*2, hidden*2, kernel_size=3, stride=1, padding=1, groups=hidden*2, bias=bias)
|
| 476 |
+
self.proj_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)
|
| 477 |
+
|
| 478 |
+
self.conv_input = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
|
| 479 |
+
self.fuse_out = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
|
| 480 |
+
self.mamba = Mamba(d_model=hidden // 2)
|
| 481 |
+
|
| 482 |
+
def forward_core(self, x):
|
| 483 |
+
B, C, H, W = x.size()
|
| 484 |
+
|
| 485 |
+
x_input = torch.chunk(self.conv_input(x), 2, dim=1)
|
| 486 |
+
|
| 487 |
+
feature_h = (x_input[0]).permute(0, 2, 3, 1).contiguous()
|
| 488 |
+
feature_h = feature_h.view(B * H, W, C//2)
|
| 489 |
+
|
| 490 |
+
feature_v = (x_input[1]).permute(0, 3, 2, 1).contiguous()
|
| 491 |
+
feature_v = feature_v.view(B * W, H, C//2)
|
| 492 |
+
|
| 493 |
+
if H == W:
|
| 494 |
+
feature = torch.cat((feature_h, feature_v), dim=0) # B * H * 2, W, C//2
|
| 495 |
+
scan_output = self.mamba(feature)
|
| 496 |
+
scan_output = torch.chunk(scan_output, 2, dim=0)
|
| 497 |
+
scan_output_h = scan_output[0]
|
| 498 |
+
scan_output_v = scan_output[1]
|
| 499 |
+
else:
|
| 500 |
+
scan_output_h = self.mamba(feature_h)
|
| 501 |
+
scan_output_v = self.mamba(feature_v)
|
| 502 |
+
|
| 503 |
+
scan_output_h = scan_output_h.view(B, H, W, C//2).permute(0, 3, 1, 2).contiguous()
|
| 504 |
+
scan_output_v = scan_output_v.view(B, W, H, C//2).permute(0, 3, 2, 1).contiguous()
|
| 505 |
+
scan_output = self.fuse_out(torch.cat((scan_output_h, scan_output_v), dim=1))
|
| 506 |
+
|
| 507 |
+
return scan_output
|
| 508 |
+
|
| 509 |
+
def forward(self, x):
|
| 510 |
+
x = self.proj_in(x)
|
| 511 |
+
x, x_ = self.dwconv(x).chunk(2, dim=1)
|
| 512 |
+
x = self.forward_core(x)
|
| 513 |
+
x = F.silu(x_) * x
|
| 514 |
+
x = self.proj_out(x)
|
| 515 |
+
return x
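# How the intra-strip scan above works: one half of the features is flattened
# into horizontal strips (each row becomes a length-W sequence) and the other
# half into vertical strips (each column becomes a length-H sequence), and
# each strip is processed by a shared Mamba SSM. When H == W the two strip
# batches have identical shapes, so they are concatenated along the batch dim
# and scanned in a single Mamba call; otherwise two calls are made. The two
# directional outputs are fused back into a (B, C, H, W) map by a 1x1 conv.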
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
class Inter_VSSM(nn.Module):
|
| 519 |
+
def __init__(self, dim, vssm_expansion_factor, bias): # gated = True
|
| 520 |
+
super(Inter_VSSM, self).__init__()
|
| 521 |
+
hidden = int(dim*vssm_expansion_factor)
|
| 522 |
+
|
| 523 |
+
self.proj_in = nn.Conv2d(dim, hidden*2, kernel_size=1, bias=bias)
|
| 524 |
+
self.dwconv = nn.Conv2d(hidden*2, hidden*2, kernel_size=3, stride=1, padding=1, groups=hidden*2, bias=bias)
|
| 525 |
+
self.proj_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)
|
| 526 |
+
|
| 527 |
+
self.avg_pool = nn.AdaptiveAvgPool2d((None,1))
|
| 528 |
+
self.conv_input = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
|
| 529 |
+
self.fuse_out = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
|
| 530 |
+
self.mamba = Mamba(d_model=hidden // 2)
|
| 531 |
+
self.sigmoid = nn.Sigmoid()
|
| 532 |
+
|
| 533 |
+
def forward_core(self, x):
|
| 534 |
+
B, C, H, W = x.size()
|
| 535 |
+
|
| 536 |
+
x_input = torch.chunk(self.conv_input(x), 2, dim=1) # B, C, H, W
|
| 537 |
+
|
| 538 |
+
feature_h = x_input[0].permute(0, 2, 1, 3).contiguous() # B, H, C//2, W
|
| 539 |
+
feature_h_score = self.avg_pool(feature_h) # B, H, C//2, 1
|
| 540 |
+
feature_h_score = feature_h_score.view(B, H, -1)
|
| 541 |
+
|
| 542 |
+
feature_v = x_input[1].permute(0, 3, 1, 2).contiguous() # B, W, C//2, H
|
| 543 |
+
feature_v_score = self.avg_pool(feature_v) # B, W, C//2, 1
|
| 544 |
+
feature_v_score = feature_v_score.view(B, W, -1)
|
| 545 |
+
|
| 546 |
+
if H == W:
|
| 547 |
+
feature_score = torch.cat((feature_h_score, feature_v_score), dim=0) # B * 2, W or H, C//2
|
| 548 |
+
scan_score = self.mamba(feature_score)
|
| 549 |
+
scan_score = torch.chunk(scan_score, 2, dim=0)
|
| 550 |
+
scan_score_h = scan_score[0]
|
| 551 |
+
scan_score_v = scan_score[1]
|
| 552 |
+
else:
|
| 553 |
+
scan_score_h = self.mamba(feature_h_score)
|
| 554 |
+
scan_score_v = self.mamba(feature_v_score)
|
| 555 |
+
|
| 556 |
+
scan_score_h = self.sigmoid(scan_score_h)
|
| 557 |
+
scan_score_v = self.sigmoid(scan_score_v)
|
| 558 |
+
feature_h = feature_h*scan_score_h[:,:,:,None]
|
| 559 |
+
feature_v = feature_v*scan_score_v[:,:,:,None]
|
| 560 |
+
feature_h = feature_h.view(B, H, C//2, W).permute(0, 2, 1, 3).contiguous()
|
| 561 |
+
feature_v = feature_v.view(B, W, C//2, H).permute(0, 2, 3, 1).contiguous()
|
| 562 |
+
output = self.fuse_out(torch.cat((feature_h, feature_v), dim=1))
|
| 563 |
+
|
| 564 |
+
return output
|
| 565 |
+
|
| 566 |
+
def forward(self, x):
|
| 567 |
+
x = self.proj_in(x)
|
| 568 |
+
x, x_ = self.dwconv(x).chunk(2, dim=1)
|
| 569 |
+
x = self.forward_core(x)
|
| 570 |
+
x = F.silu(x_) * x
|
| 571 |
+
x = self.proj_out(x)
|
| 572 |
+
return x
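# How the inter-strip modulation above works: each horizontal strip (row) and
# vertical strip (column) is average-pooled along its length into a channel
# descriptor, the sequence of strip descriptors is scanned by a shared Mamba
# SSM, and a sigmoid turns the result into per-strip channel gates that
# re-weight the original strips before the 1x1 fusion conv. This models
# dependencies *across* strips, complementing the within-strip Intra_VSSM scan.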
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
##########################################################################
|
| 576 |
+
class Strip_VSSB(nn.Module):
|
| 577 |
+
def __init__(self, dim, vssm_expansion_factor, ffn_expansion_factor, bias=False, ssm=False, LayerNorm_type='WithBias'):
|
| 578 |
+
super(Strip_VSSB, self).__init__()
|
| 579 |
+
self.ssm = ssm
|
| 580 |
+
if self.ssm:
|
| 581 |
+
self.norm1_ssm = LayerNorm(dim, LayerNorm_type)
|
| 582 |
+
self.norm2_ssm = LayerNorm(dim, LayerNorm_type)
|
| 583 |
+
self.intra = Intra_VSSM(dim, vssm_expansion_factor, bias)
|
| 584 |
+
self.inter = Inter_VSSM(dim, vssm_expansion_factor, bias)
|
| 585 |
+
self.norm1_ffn = LayerNorm(dim, LayerNorm_type)
|
| 586 |
+
self.norm2_ffn = LayerNorm(dim, LayerNorm_type)
|
| 587 |
+
self.ffn1 = GDFN(dim, ffn_expansion_factor, bias)
|
| 588 |
+
self.ffn2 = GDFN(dim, ffn_expansion_factor, bias)
|
| 589 |
+
|
| 590 |
+
def forward(self, x):
|
| 591 |
+
if self.ssm:
|
| 592 |
+
x = x + self.intra(self.norm1_ssm(x))
|
| 593 |
+
x = x + self.ffn1(self.norm1_ffn(x))
|
| 594 |
+
if self.ssm:
|
| 595 |
+
x = x + self.inter(self.norm2_ssm(x))
|
| 596 |
+
x = x + self.ffn2(self.norm2_ffn(x))
|
| 597 |
+
|
| 598 |
+
return x
|
| 599 |
+
|
| 600 |
+
|
| 601 |
+
##########################################################################
|
| 602 |
+
##---------- Cross-level Feature Fusion by Adding Sigmoid(KL-Div) * Multi-Scale Feat -----------------------
|
| 603 |
+
class CLFF(nn.Module):
|
| 604 |
+
def __init__(self, dim, dim_n1, dim_n2, bias=False):
|
| 605 |
+
super(CLFF, self).__init__()
|
| 606 |
+
|
| 607 |
+
self.conv = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
| 608 |
+
self.conv_n1 = nn.Conv2d(dim_n1, dim, kernel_size=1, bias=bias)
|
| 609 |
+
self.conv_n2 = nn.Conv2d(dim_n2, dim, kernel_size=1, bias=bias)
|
| 610 |
+
self.fuse_out1 = nn.Conv2d(dim*2, dim, kernel_size=1, bias=bias)
|
| 611 |
+
|
| 612 |
+
self.log_sigmoid = nn.LogSigmoid()
|
| 613 |
+
self.sigmoid = nn.Sigmoid()
|
| 614 |
+
|
| 615 |
+
def forward(self, x, n1, n2):
|
| 616 |
+
x_ = self.conv(x)
|
| 617 |
+
n1_ = self.conv_n1(n1)
|
| 618 |
+
n2_ = self.conv_n2(n2)
|
| 619 |
+
kl_n1 = F.kl_div(input=self.log_sigmoid(n1_), target=self.log_sigmoid(x_), log_target=True)
|
| 620 |
+
kl_n2 = F.kl_div(input=self.log_sigmoid(n2_), target=self.log_sigmoid(x_), log_target=True)
|
| 621 |
+
#g = self.sigmoid(x_)
|
| 622 |
+
g1 = self.sigmoid(kl_n1)
|
| 623 |
+
g2 = self.sigmoid(kl_n2)
|
| 624 |
+
#x = (1 + g) * x_ + (1 - g) * (g1 * n1_ + g2 * n2_)
|
| 625 |
+
x = self.fuse_out1(torch.cat((x_, g1 * n1_ + g2 * n2_), dim=1))
|
| 626 |
+
|
| 627 |
+
return x
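# Note: F.kl_div is called with its default reduction ('mean'), so kl_n1 and
# kl_n2 are scalars; g1 and g2 therefore act as single global gates on the
# rescaled neighbour features n1_/n2_ before the 1x1 fusion conv, rather than
# as per-pixel weights.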
|
| 628 |
+
|
| 629 |
+
##########################################################################
|
| 630 |
+
##---------- XYScanNet -----------------------
|
| 631 |
+
class XYScanNet(nn.Module):
|
| 632 |
+
def __init__(self,
|
| 633 |
+
inp_channels=3,
|
| 634 |
+
out_channels=3,
|
| 635 |
+
dim = 72, # 48, 72, 96, 120, 144
|
| 636 |
+
num_blocks = [3,3,6],
|
| 637 |
+
vssm_expansion_factor = 1, # 1 or 2
|
| 638 |
+
ffn_expansion_factor = 1, # 1 or 3
|
| 639 |
+
bias = False,
|
| 640 |
+
LayerNorm_type = 'WithBias', ## Other option 'BiasFree'
|
| 641 |
+
):
|
| 642 |
+
|
| 643 |
+
super(XYScanNet, self).__init__()
|
| 644 |
+
|
| 645 |
+
self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
|
| 646 |
+
|
| 647 |
+
self.encoder_level1 = nn.Sequential(*[Strip_VSSB(dim=dim, vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
|
| 648 |
+
bias=bias, ssm=False, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
|
| 649 |
+
|
| 650 |
+
self.down1_2 = Downsample(dim) ## From Level 1 to Level 2
|
| 651 |
+
self.encoder_level2 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**1), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
|
| 652 |
+
bias=bias, ssm=False, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
|
| 653 |
+
|
| 654 |
+
self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3
|
| 655 |
+
self.encoder_level3 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**2), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
|
| 656 |
+
bias=bias, ssm=False, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
|
| 657 |
+
|
| 658 |
+
self.decoder_level3 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**2), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
|
| 659 |
+
bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
|
| 660 |
+
|
| 661 |
+
self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2
|
| 662 |
+
self.clff_level2 = CLFF(int(dim*2**1), dim_n1=int(dim*2**0), dim_n2=(dim*2**2), bias=bias)
|
| 663 |
+
self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
|
| 664 |
+
self.decoder_level2 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**1), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
|
| 665 |
+
bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
|
| 666 |
+
|
| 667 |
+
self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1
|
| 668 |
+
self.clff_level1 = CLFF(int(dim*2**0), dim_n1=int(dim*2**1), dim_n2=(dim*2**2), bias=bias)
|
| 669 |
+
self.reduce_chan_level1 = nn.Conv2d(int(dim*2**1), int(dim*2**0), kernel_size=1, bias=bias)
|
| 670 |
+
self.decoder_level1 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**0), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
|
| 671 |
+
bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
|
| 672 |
+
|
| 673 |
+
# self.refinement = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**0), expansion_factor=expansion_factor, bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])
|
| 674 |
+
|
| 675 |
+
self.output = nn.Conv2d(int(dim*2**0), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
|
| 676 |
+
|
| 677 |
+
def forward(self, inp_img):
|
| 678 |
+
|
| 679 |
+
# Encoder
|
| 680 |
+
inp_enc_level1 = self.patch_embed(inp_img)
|
| 681 |
+
out_enc_level1 = self.encoder_level1(inp_enc_level1)
|
| 682 |
+
out_enc_level1_2 = F.interpolate(out_enc_level1, scale_factor=0.5) # dim channels, lvl1 feature down-scaled to lvl2 resolution
|
| 683 |
+
|
| 684 |
+
inp_enc_level2 = self.down1_2(out_enc_level1)
|
| 685 |
+
out_enc_level2 = self.encoder_level2(inp_enc_level2)
|
| 686 |
+
out_enc_level2_1 = F.interpolate(out_enc_level2, scale_factor=2) # dim*2, lvl2 up-scaled to lvl1
|
| 687 |
+
|
| 688 |
+
inp_enc_level3 = self.down2_3(out_enc_level2)
|
| 689 |
+
out_enc_level3 = self.encoder_level3(inp_enc_level3)
|
| 690 |
+
out_enc_level3_2 = F.interpolate(out_enc_level3, scale_factor=2) # dim*2**2, lvl3 up-scaled to lvl2 (lvl3->lvl2)
|
| 691 |
+
out_enc_level3_1 = F.interpolate(out_enc_level3_2, scale_factor=2) # dim*2**2, lvl3 up-scaled to lvl1 (lvl3->lvl2->lvl1)
|
| 692 |
+
|
| 693 |
+
out_enc_level1 = self.clff_level1(out_enc_level1, out_enc_level2_1, out_enc_level3_1)
|
| 694 |
+
out_enc_level2 = self.clff_level2(out_enc_level2, out_enc_level1_2, out_enc_level3_2)
|
| 695 |
+
|
| 696 |
+
# Decoder
|
| 697 |
+
out_dec_level3_decomp1 = self.decoder_level3(out_enc_level3)
|
| 698 |
+
|
| 699 |
+
inp_dec_level2_decomp1 = self.up3_2(out_dec_level3_decomp1)
|
| 700 |
+
inp_dec_level2_decomp1 = self.reduce_chan_level2(torch.cat((inp_dec_level2_decomp1, out_enc_level2), dim=1))
|
| 701 |
+
out_dec_level2_decomp1 = self.decoder_level2(inp_dec_level2_decomp1)
|
| 702 |
+
|
| 703 |
+
inp_dec_level1_decomp1 = self.up2_1(out_dec_level2_decomp1)
|
| 704 |
+
inp_dec_level1_decomp1 = self.reduce_chan_level1(torch.cat((inp_dec_level1_decomp1, out_enc_level1), dim=1))
|
| 705 |
+
out_dec_level1_decomp1 = self.decoder_level1(inp_dec_level1_decomp1)
|
| 706 |
+
|
| 707 |
+
out_dec_level1_decomp1 = self.output(out_dec_level1_decomp1)
|
| 708 |
+
|
| 709 |
+
out_dec_level1 = out_dec_level1_decomp1 + inp_img
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
return out_dec_level1, out_dec_level1_decomp1, None
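# A minimal inference sketch (an assumption, not part of the training code):
# the two Downsample/Upsample stages mean H and W should be divisible by 4
# for the skip connections to line up, so inputs are typically reflect-padded
# to a safe multiple (e.g. 8) and cropped back afterwards:
#   >>> _, _, h, w = blur.shape
#   >>> ph, pw = (8 - h % 8) % 8, (8 - w % 8) % 8
#   >>> out, _, _ = model(F.pad(blur, (0, pw, 0, ph), mode='reflect'))
#   >>> sharp = out[..., :h, :w]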
|
| 713 |
+
|
| 714 |
+
def count_parameters(model):
|
| 715 |
+
total = sum(p.numel() for p in model.parameters())
|
| 716 |
+
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 717 |
+
print(f"Total parameters: {total:,}")
|
| 718 |
+
print(f"Trainable parameters: {trainable:,}")
|
| 719 |
+
|
| 720 |
+
|
| 721 |
+
def main():
|
| 722 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 723 |
+
model = XYScanNet().to(device)
|
| 724 |
+
|
| 725 |
+
print("Model architecture:\n")
|
| 726 |
+
print(model)
|
| 727 |
+
|
| 728 |
+
count_parameters(model)
|
| 729 |
+
|
| 730 |
+
# Optionally test with a dummy input
|
| 731 |
+
dummy_input = torch.randn(1, 3, 256, 256).to(device)
|
| 732 |
+
output, _, _ = model(dummy_input)
|
| 733 |
+
print(f"Output shape: {output.shape}")
|
| 734 |
+
|
| 735 |
+
|
| 736 |
+
if __name__ == "__main__":
|
| 737 |
+
main()
|
models/XYScanNetP.py
ADDED
|
@@ -0,0 +1,737 @@
|
| 1 |
+
import numbers
|
| 2 |
+
import math
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
|
| 10 |
+
from einops import rearrange, repeat
|
| 11 |
+
|
| 12 |
+
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
| 16 |
+
except ImportError:
|
| 17 |
+
causal_conv1d_fn, causal_conv1d_update = None, None
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
| 21 |
+
except ImportError:
|
| 22 |
+
selective_state_update = None
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
|
| 26 |
+
except ImportError:
|
| 27 |
+
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def to_3d(x):
|
| 31 |
+
return rearrange(x, 'b c h w -> b (h w) c')
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def to_4d(x, h, w):
|
| 35 |
+
return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class BiasFree_LayerNorm(nn.Module):
|
| 39 |
+
def __init__(self, normalized_shape):
|
| 40 |
+
super(BiasFree_LayerNorm, self).__init__()
|
| 41 |
+
if isinstance(normalized_shape, numbers.Integral):
|
| 42 |
+
normalized_shape = (normalized_shape,)
|
| 43 |
+
normalized_shape = torch.Size(normalized_shape)
|
| 44 |
+
|
| 45 |
+
assert len(normalized_shape) == 1
|
| 46 |
+
|
| 47 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
| 48 |
+
self.normalized_shape = normalized_shape
|
| 49 |
+
|
| 50 |
+
def forward(self, x):
|
| 51 |
+
sigma = x.var(-1, keepdim=True, unbiased=False)
|
| 52 |
+
return x / torch.sqrt(sigma + 1e-5) * self.weight
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class WithBias_LayerNorm(nn.Module):
|
| 56 |
+
def __init__(self, normalized_shape):
|
| 57 |
+
super(WithBias_LayerNorm, self).__init__()
|
| 58 |
+
if isinstance(normalized_shape, numbers.Integral):
|
| 59 |
+
normalized_shape = (normalized_shape,)
|
| 60 |
+
normalized_shape = torch.Size(normalized_shape)
|
| 61 |
+
|
| 62 |
+
assert len(normalized_shape) == 1
|
| 63 |
+
|
| 64 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
| 65 |
+
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
| 66 |
+
self.normalized_shape = normalized_shape
|
| 67 |
+
|
| 68 |
+
def forward(self, x):
|
| 69 |
+
mu = x.mean(-1, keepdim=True)
|
| 70 |
+
sigma = x.var(-1, keepdim=True, unbiased=False)
|
| 71 |
+
return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class LayerNorm(nn.Module):
|
| 75 |
+
def __init__(self, dim, LayerNorm_type):
|
| 76 |
+
super(LayerNorm, self).__init__()
|
| 77 |
+
if LayerNorm_type == 'BiasFree':
|
| 78 |
+
self.body = BiasFree_LayerNorm(dim)
|
| 79 |
+
else:
|
| 80 |
+
self.body = WithBias_LayerNorm(dim)
|
| 81 |
+
|
| 82 |
+
def forward(self, x):
|
| 83 |
+
h, w = x.shape[-2:]
|
| 84 |
+
return to_4d(self.body(to_3d(x)), h, w)
|
| 85 |
+
|
| 86 |
+
##########################################################################
|
| 87 |
+
def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1):
|
| 88 |
+
return nn.Conv2d(
|
| 89 |
+
in_channels, out_channels, kernel_size,
|
| 90 |
+
padding=(kernel_size//2), bias=bias, stride = stride)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
"""
|
| 94 |
+
Borrow from "https://github.com/state-spaces/mamba.git"
|
| 95 |
+
@article{mamba,
|
| 96 |
+
title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
|
| 97 |
+
author={Gu, Albert and Dao, Tri},
|
| 98 |
+
journal={arXiv preprint arXiv:2312.00752},
|
| 99 |
+
year={2023}
|
| 100 |
+
}
|
| 101 |
+
"""
|
| 102 |
+
class Mamba(nn.Module):
|
| 103 |
+
def __init__(
|
| 104 |
+
self,
|
| 105 |
+
d_model,
|
| 106 |
+
d_state=16,
|
| 107 |
+
d_conv=4,
|
| 108 |
+
expand=2,
|
| 109 |
+
dt_rank="auto",
|
| 110 |
+
dt_min=0.001,
|
| 111 |
+
dt_max=0.1,
|
| 112 |
+
dt_init="random",
|
| 113 |
+
dt_scale=1.0,
|
| 114 |
+
dt_init_floor=1e-4,
|
| 115 |
+
conv_bias=True,
|
| 116 |
+
bias=False,
|
| 117 |
+
use_fast_path=True, # Fused kernel options
|
| 118 |
+
layer_idx=None,
|
| 119 |
+
device=None,
|
| 120 |
+
dtype=None,
|
| 121 |
+
):
|
| 122 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 123 |
+
super().__init__()
|
| 124 |
+
self.d_model = d_model
|
| 125 |
+
self.d_state = d_state
|
| 126 |
+
self.d_conv = d_conv
|
| 127 |
+
self.expand = expand
|
| 128 |
+
self.d_inner = int(self.expand * self.d_model)
|
| 129 |
+
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
|
| 130 |
+
self.use_fast_path = use_fast_path
|
| 131 |
+
self.layer_idx = layer_idx
|
| 132 |
+
|
| 133 |
+
self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
|
| 134 |
+
|
| 135 |
+
self.conv1d = nn.Conv1d(
|
| 136 |
+
in_channels=self.d_inner,
|
| 137 |
+
out_channels=self.d_inner,
|
| 138 |
+
bias=conv_bias,
|
| 139 |
+
kernel_size=d_conv,
|
| 140 |
+
groups=self.d_inner,
|
| 141 |
+
padding=d_conv - 1,
|
| 142 |
+
**factory_kwargs,
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
self.activation = "silu"
|
| 146 |
+
self.act = nn.SiLU()
|
| 147 |
+
|
| 148 |
+
self.x_proj = nn.Linear(
|
| 149 |
+
self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
|
| 150 |
+
)
|
| 151 |
+
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
|
| 152 |
+
|
| 153 |
+
# Initialize special dt projection to preserve variance at initialization
|
| 154 |
+
dt_init_std = self.dt_rank**-0.5 * dt_scale
|
| 155 |
+
if dt_init == "constant":
|
| 156 |
+
nn.init.constant_(self.dt_proj.weight, dt_init_std)
|
| 157 |
+
elif dt_init == "random":
|
| 158 |
+
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
|
| 159 |
+
else:
|
| 160 |
+
raise NotImplementedError
|
| 161 |
+
|
| 162 |
+
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
|
| 163 |
+
dt = torch.exp(
|
| 164 |
+
torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
|
| 165 |
+
+ math.log(dt_min)
|
| 166 |
+
).clamp(min=dt_init_floor)
|
| 167 |
+
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
| 168 |
+
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
| 169 |
+
with torch.no_grad():
|
| 170 |
+
self.dt_proj.bias.copy_(inv_dt)
|
| 171 |
+
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
|
| 172 |
+
self.dt_proj.bias._no_reinit = True
|
| 173 |
+
|
| 174 |
+
# S4D real initialization
|
| 175 |
+
A = repeat(
|
| 176 |
+
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
|
| 177 |
+
"n -> d n",
|
| 178 |
+
d=self.d_inner,
|
| 179 |
+
).contiguous()
|
| 180 |
+
A_log = torch.log(A) # Keep A_log in fp32
|
| 181 |
+
self.A_log = nn.Parameter(A_log)
|
| 182 |
+
self.A_log._no_weight_decay = True
|
| 183 |
+
|
| 184 |
+
# D "skip" parameter
|
| 185 |
+
self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
|
| 186 |
+
self.D._no_weight_decay = True
|
| 187 |
+
|
| 188 |
+
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
|
| 189 |
+
|
| 190 |
+
def forward(self, hidden_states, inference_params=None):
|
| 191 |
+
"""
|
| 192 |
+
hidden_states: (B, L, D)
|
| 193 |
+
Returns: same shape as hidden_states
|
| 194 |
+
"""
|
| 195 |
+
batch, seqlen, dim = hidden_states.shape
|
| 196 |
+
|
| 197 |
+
conv_state, ssm_state = None, None
|
| 198 |
+
if inference_params is not None:
|
| 199 |
+
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
|
| 200 |
+
if inference_params.seqlen_offset > 0:
|
| 201 |
+
# The states are updated inplace
|
| 202 |
+
out, _, _ = self.step(hidden_states, conv_state, ssm_state)
|
| 203 |
+
return out
|
| 204 |
+
|
| 205 |
+
# We do matmul and transpose BLH -> HBL at the same time
|
| 206 |
+
xz = rearrange(
|
| 207 |
+
self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
|
| 208 |
+
"d (b l) -> b d l",
|
| 209 |
+
l=seqlen,
|
| 210 |
+
)
|
| 211 |
+
if self.in_proj.bias is not None:
|
| 212 |
+
xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
|
| 213 |
+
|
| 214 |
+
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
| 215 |
+
# In the backward pass we write dx and dz next to each other to avoid torch.cat
|
| 216 |
+
if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None: # Doesn't support outputting the states
|
| 217 |
+
out = mamba_inner_fn(
|
| 218 |
+
xz,
|
| 219 |
+
self.conv1d.weight,
|
| 220 |
+
self.conv1d.bias,
|
| 221 |
+
self.x_proj.weight,
|
| 222 |
+
self.dt_proj.weight,
|
| 223 |
+
self.out_proj.weight,
|
| 224 |
+
self.out_proj.bias,
|
| 225 |
+
A,
|
| 226 |
+
None, # input-dependent B
|
| 227 |
+
None, # input-dependent C
|
| 228 |
+
self.D.float(),
|
| 229 |
+
delta_bias=self.dt_proj.bias.float(),
|
| 230 |
+
delta_softplus=True,
|
| 231 |
+
)
|
| 232 |
+
else:
|
| 233 |
+
x, z = xz.chunk(2, dim=1)
|
| 234 |
+
# Compute short convolution
|
| 235 |
+
if conv_state is not None:
|
| 236 |
+
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
|
| 237 |
+
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
|
| 238 |
+
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
|
| 239 |
+
if causal_conv1d_fn is None:
|
| 240 |
+
x = self.act(self.conv1d(x)[..., :seqlen])
|
| 241 |
+
else:
|
| 242 |
+
assert self.activation in ["silu", "swish"]
|
| 243 |
+
x = causal_conv1d_fn(
|
| 244 |
+
x=x,
|
| 245 |
+
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
| 246 |
+
bias=self.conv1d.bias,
|
| 247 |
+
activation=self.activation,
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
# We're careful here about the layout, to avoid extra transposes.
|
| 251 |
+
# We want dt to have d as the slowest moving dimension
|
| 252 |
+
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
| 253 |
+
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
|
| 254 |
+
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
| 255 |
+
dt = self.dt_proj.weight @ dt.t()
|
| 256 |
+
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
|
| 257 |
+
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
| 258 |
+
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
| 259 |
+
assert self.activation in ["silu", "swish"]
|
| 260 |
+
y = selective_scan_fn(
|
| 261 |
+
x,
|
| 262 |
+
dt,
|
| 263 |
+
A,
|
| 264 |
+
B,
|
| 265 |
+
C,
|
| 266 |
+
self.D.float(),
|
| 267 |
+
z=z,
|
| 268 |
+
delta_bias=self.dt_proj.bias.float(),
|
| 269 |
+
delta_softplus=True,
|
| 270 |
+
return_last_state=ssm_state is not None,
|
| 271 |
+
)
|
| 272 |
+
if ssm_state is not None:
|
| 273 |
+
y, last_state = y
|
| 274 |
+
ssm_state.copy_(last_state)
|
| 275 |
+
y = rearrange(y, "b d l -> b l d")
|
| 276 |
+
out = self.out_proj(y)
|
| 277 |
+
return out
|
| 278 |
+
|
| 279 |
+
def step(self, hidden_states, conv_state, ssm_state):
|
| 280 |
+
dtype = hidden_states.dtype
|
| 281 |
+
assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
|
| 282 |
+
xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
|
| 283 |
+
x, z = xz.chunk(2, dim=-1) # (B D)
|
| 284 |
+
|
| 285 |
+
# Conv step
|
| 286 |
+
if causal_conv1d_update is None:
|
| 287 |
+
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
|
| 288 |
+
conv_state[:, :, -1] = x
|
| 289 |
+
x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
|
| 290 |
+
if self.conv1d.bias is not None:
|
| 291 |
+
x = x + self.conv1d.bias
|
| 292 |
+
x = self.act(x).to(dtype=dtype)
|
| 293 |
+
else:
|
| 294 |
+
x = causal_conv1d_update(
|
| 295 |
+
x,
|
| 296 |
+
conv_state,
|
| 297 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
| 298 |
+
self.conv1d.bias,
|
| 299 |
+
self.activation,
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
x_db = self.x_proj(x) # (B dt_rank+2*d_state)
|
| 303 |
+
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
| 304 |
+
# Don't add dt_bias here
|
| 305 |
+
dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
|
| 306 |
+
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
| 307 |
+
|
| 308 |
+
# SSM step
|
| 309 |
+
if selective_state_update is None:
|
| 310 |
+
# Discretize A and B
|
| 311 |
+
dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
|
| 312 |
+
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
|
| 313 |
+
dB = torch.einsum("bd,bn->bdn", dt, B)
|
| 314 |
+
ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
|
| 315 |
+
y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
|
| 316 |
+
y = y + self.D.to(dtype) * x
|
| 317 |
+
y = y * self.act(z) # (B D)
|
| 318 |
+
else:
|
| 319 |
+
y = selective_state_update(
|
| 320 |
+
ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
out = self.out_proj(y)
|
| 324 |
+
return out.unsqueeze(1), conv_state, ssm_state
|
| 325 |
+
|
| 326 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
| 327 |
+
device = self.out_proj.weight.device
|
| 328 |
+
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
|
| 329 |
+
conv_state = torch.zeros(
|
| 330 |
+
batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
|
| 331 |
+
)
|
| 332 |
+
ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
|
| 333 |
+
# ssm_dtype = torch.float32
|
| 334 |
+
ssm_state = torch.zeros(
|
| 335 |
+
batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
|
| 336 |
+
)
|
| 337 |
+
return conv_state, ssm_state
|
| 338 |
+
|
| 339 |
+
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
|
| 340 |
+
assert self.layer_idx is not None
|
| 341 |
+
if self.layer_idx not in inference_params.key_value_memory_dict:
|
| 342 |
+
batch_shape = (batch_size,)
|
| 343 |
+
conv_state = torch.zeros(
|
| 344 |
+
batch_size,
|
| 345 |
+
self.d_model * self.expand,
|
| 346 |
+
self.d_conv,
|
| 347 |
+
device=self.conv1d.weight.device,
|
| 348 |
+
dtype=self.conv1d.weight.dtype,
|
| 349 |
+
)
|
| 350 |
+
ssm_state = torch.zeros(
|
| 351 |
+
batch_size,
|
| 352 |
+
self.d_model * self.expand,
|
| 353 |
+
self.d_state,
|
| 354 |
+
device=self.dt_proj.weight.device,
|
| 355 |
+
dtype=self.dt_proj.weight.dtype,
|
| 356 |
+
# dtype=torch.float32,
|
| 357 |
+
)
|
| 358 |
+
inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
|
| 359 |
+
else:
|
| 360 |
+
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
|
| 361 |
+
# TODO: What if batch size changes between generation, and we reuse the same states?
|
| 362 |
+
if initialize_states:
|
| 363 |
+
conv_state.zero_()
|
| 364 |
+
ssm_state.zero_()
|
| 365 |
+
return conv_state, ssm_state
|
| 366 |
+
|
| 367 |
+
##########################################################################
|
| 368 |
+
## Feed-forward Network
|
| 369 |
+
class FFN(nn.Module):
|
| 370 |
+
def __init__(self, dim, ffn_expansion_factor, bias):
|
| 371 |
+
super(FFN, self).__init__()
|
| 372 |
+
|
| 373 |
+
hidden_features = int(dim*ffn_expansion_factor)
|
| 374 |
+
|
| 375 |
+
self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
|
| 376 |
+
|
| 377 |
+
self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias, dilation=1)
|
| 378 |
+
|
| 379 |
+
self.win_size = 8
|
| 380 |
+
|
| 381 |
+
self.modulator = nn.Parameter(torch.ones(self.win_size, self.win_size, dim*2)) # modulator
|
| 382 |
+
|
| 383 |
+
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
|
| 384 |
+
|
| 385 |
+
def forward(self, x):
|
| 386 |
+
b, c, h, w = x.shape
|
| 387 |
+
h1, w1 = h//self.win_size, w//self.win_size
|
| 388 |
+
x = self.project_in(x)
|
| 389 |
+
x = self.dwconv(x)
|
| 390 |
+
x_win = rearrange(x, 'b c (wsh h1) (wsw w1) -> b h1 w1 wsh wsw c', wsh=self.win_size, wsw=self.win_size)
|
| 391 |
+
x_win = x_win * self.modulator
|
| 392 |
+
x = rearrange(x_win, 'b h1 w1 wsh wsw c -> b c (wsh h1) (wsw w1)', wsh=self.win_size, wsw=self.win_size, h1=h1, w1=w1)
|
| 393 |
+
x1, x2 = x.chunk(2, dim=1)
|
| 394 |
+
x = x1 * x2
|
| 395 |
+
x = self.project_out(x)
|
| 396 |
+
return x
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
##########################################################################
|
| 400 |
+
## Gated Depth-wise Feed-forward Network (GDFN)
|
| 401 |
+
class GDFN(nn.Module):
|
| 402 |
+
def __init__(self, dim, ffn_expansion_factor, bias):
|
| 403 |
+
super(GDFN, self).__init__()
|
| 404 |
+
|
| 405 |
+
hidden_features = int(dim*ffn_expansion_factor)
|
| 406 |
+
|
| 407 |
+
self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
|
| 408 |
+
|
| 409 |
+
self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias, dilation=1)
|
| 410 |
+
|
| 411 |
+
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
|
| 412 |
+
|
| 413 |
+
def forward(self, x):
|
| 414 |
+
x = self.project_in(x)
|
| 415 |
+
x = self.dwconv(x)
|
| 416 |
+
x1, x2 = x.chunk(2, dim=1)
|
| 417 |
+
x = F.silu(x1) * x2
|
| 418 |
+
x = self.project_out(x)
|
| 419 |
+
return x
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
##########################################################################
|
| 423 |
+
## Overlapped image patch embedding with 3x3 Conv
|
| 424 |
+
class OverlapPatchEmbed(nn.Module):
|
| 425 |
+
def __init__(self, in_c=3, embed_dim=48, bias=False):
|
| 426 |
+
super(OverlapPatchEmbed, self).__init__()
|
| 427 |
+
|
| 428 |
+
self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)
|
| 429 |
+
|
| 430 |
+
def forward(self, x):
|
| 431 |
+
x = self.proj(x)
|
| 432 |
+
|
| 433 |
+
return x
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
##########################################################################
|
| 437 |
+
## Resizing modules
|
| 438 |
+
class Downsample(nn.Module):
|
| 439 |
+
def __init__(self, n_feat):
|
| 440 |
+
super(Downsample, self).__init__()
|
| 441 |
+
|
| 442 |
+
self.body = nn.Sequential(nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
|
| 443 |
+
nn.Conv2d(n_feat, n_feat * 2, 3, stride=1, padding=1, bias=False))
|
| 444 |
+
|
| 445 |
+
def forward(self, x):
|
| 446 |
+
return self.body(x)
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
class Upsample(nn.Module):
|
| 450 |
+
def __init__(self, n_feat):
|
| 451 |
+
super(Upsample, self).__init__()
|
| 452 |
+
|
| 453 |
+
self.body = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
|
| 454 |
+
nn.Conv2d(n_feat, n_feat // 2, 3, stride=1, padding=1, bias=False))
|
| 455 |
+
|
| 456 |
+
def forward(self, x):
|
| 457 |
+
return self.body(x)
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
"""
|
| 461 |
+
Borrow from "https://github.com/pp00704831/Stripformer-ECCV-2022-.git"
|
| 462 |
+
@inproceedings{Tsai2022Stripformer,
|
| 463 |
+
author = {Fu-Jen Tsai and Yan-Tsung Peng and Yen-Yu Lin and Chung-Chi Tsai and Chia-Wen Lin},
|
| 464 |
+
title = {Stripformer: Strip Transformer for Fast Image Deblurring},
|
| 465 |
+
booktitle = {ECCV},
|
| 466 |
+
year = {2022}
|
| 467 |
+
}
|
| 468 |
+
"""
|
| 469 |
+
class Intra_VSSM(nn.Module):
|
| 470 |
+
def __init__(self, dim, vssm_expansion_factor, bias): # gated = True
|
| 471 |
+
super(Intra_VSSM, self).__init__()
|
| 472 |
+
hidden = int(dim*vssm_expansion_factor)
|
| 473 |
+
|
| 474 |
+
self.proj_in = nn.Conv2d(dim, hidden*2, kernel_size=1, bias=bias)
|
| 475 |
+
self.dwconv = nn.Conv2d(hidden*2, hidden*2, kernel_size=3, stride=1, padding=1, groups=hidden*2, bias=bias)
|
| 476 |
+
        self.proj_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)

        self.conv_input = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
        self.fuse_out = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
        self.mamba = Mamba(d_model=hidden // 2)

    def forward_core(self, x):
        B, C, H, W = x.size()

        x_input = torch.chunk(self.conv_input(x), 2, dim=1)

        feature_h = (x_input[0]).permute(0, 2, 3, 1).contiguous()
        feature_h = feature_h.view(B * H, W, C//2)

        feature_v = (x_input[1]).permute(0, 3, 2, 1).contiguous()
        feature_v = feature_v.view(B * W, H, C//2)

        if H == W:
            feature = torch.cat((feature_h, feature_v), dim=0)  # B * H * 2, W, C//2
            scan_output = self.mamba(feature)
            scan_output = torch.chunk(scan_output, 2, dim=0)
            scan_output_h = scan_output[0]
            scan_output_v = scan_output[1]
        else:
            scan_output_h = self.mamba(feature_h)
            scan_output_v = self.mamba(feature_v)

        scan_output_h = scan_output_h.view(B, H, W, C//2).permute(0, 3, 1, 2).contiguous()
        scan_output_v = scan_output_v.view(B, W, H, C//2).permute(0, 3, 2, 1).contiguous()
        scan_output = self.fuse_out(torch.cat((scan_output_h, scan_output_v), dim=1))

        return scan_output

    def forward(self, x):
        x = self.proj_in(x)
        x, x_ = self.dwconv(x).chunk(2, dim=1)
        x = self.forward_core(x)
        x = F.silu(x_) * x
        x = self.proj_out(x)
        return x


class Inter_VSSM(nn.Module):
    def __init__(self, dim, vssm_expansion_factor, bias):  # gated = True
        super(Inter_VSSM, self).__init__()
        hidden = int(dim*vssm_expansion_factor)

        self.proj_in = nn.Conv2d(dim, hidden*2, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(hidden*2, hidden*2, kernel_size=3, stride=1, padding=1, groups=hidden*2, bias=bias)
        self.proj_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)

        self.avg_pool = nn.AdaptiveAvgPool2d((None, 1))
        self.conv_input = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
        self.fuse_out = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
        self.mamba = Mamba(d_model=hidden // 2)
        self.sigmoid = nn.Sigmoid()

    def forward_core(self, x):
        B, C, H, W = x.size()

        x_input = torch.chunk(self.conv_input(x), 2, dim=1)  # B, C, H, W

        feature_h = x_input[0].permute(0, 2, 1, 3).contiguous()  # B, H, C//2, W
        feature_h_score = self.avg_pool(feature_h)  # B, H, C//2, 1
        feature_h_score = feature_h_score.view(B, H, -1)

        feature_v = x_input[1].permute(0, 3, 1, 2).contiguous()  # B, W, C//2, H
        feature_v_score = self.avg_pool(feature_v)  # B, W, C//2, 1
        feature_v_score = feature_v_score.view(B, W, -1)

        if H == W:
            feature_score = torch.cat((feature_h_score, feature_v_score), dim=0)  # B * 2, W or H, C//2
            scan_score = self.mamba(feature_score)
            scan_score = torch.chunk(scan_score, 2, dim=0)
            scan_score_h = scan_score[0]
            scan_score_v = scan_score[1]
        else:
            scan_score_h = self.mamba(feature_h_score)
            scan_score_v = self.mamba(feature_v_score)

        scan_score_h = self.sigmoid(scan_score_h)
        scan_score_v = self.sigmoid(scan_score_v)
        feature_h = feature_h * scan_score_h[:, :, :, None]
        feature_v = feature_v * scan_score_v[:, :, :, None]
        feature_h = feature_h.view(B, H, C//2, W).permute(0, 2, 1, 3).contiguous()
        feature_v = feature_v.view(B, W, C//2, H).permute(0, 2, 3, 1).contiguous()
        output = self.fuse_out(torch.cat((feature_h, feature_v), dim=1))

        return output

    def forward(self, x):
        x = self.proj_in(x)
        x, x_ = self.dwconv(x).chunk(2, dim=1)
        x = self.forward_core(x)
        x = F.silu(x_) * x
        x = self.proj_out(x)
        return x


##########################################################################
class Strip_VSSB(nn.Module):
    def __init__(self, dim, vssm_expansion_factor, ffn_expansion_factor, bias=False, ssm=False, LayerNorm_type='WithBias'):
        super(Strip_VSSB, self).__init__()
        self.ssm = ssm
        if self.ssm == True:
            self.norm1_ssm = LayerNorm(dim, LayerNorm_type)
            self.norm2_ssm = LayerNorm(dim, LayerNorm_type)
            self.intra = Intra_VSSM(dim, vssm_expansion_factor, bias)
            self.inter = Inter_VSSM(dim, vssm_expansion_factor, bias)
        self.norm1_ffn = LayerNorm(dim, LayerNorm_type)
        self.norm2_ffn = LayerNorm(dim, LayerNorm_type)
        self.ffn1 = GDFN(dim, ffn_expansion_factor, bias)
        self.ffn2 = GDFN(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        if self.ssm == True:
            x = x + self.intra(self.norm1_ssm(x))
        x = x + self.ffn1(self.norm1_ffn(x))
        if self.ssm == True:
            x = x + self.inter(self.norm2_ssm(x))
        x = x + self.ffn2(self.norm2_ffn(x))

        return x


##########################################################################
##---------- Cross-level Feature Fusion by Adding Sigmoid(KL-Div) * Multi-Scale Feat -----------------------
class CLFF(nn.Module):
    def __init__(self, dim, dim_n1, dim_n2, bias=False):
        super(CLFF, self).__init__()

        self.conv = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self.conv_n1 = nn.Conv2d(dim_n1, dim, kernel_size=1, bias=bias)
        self.conv_n2 = nn.Conv2d(dim_n2, dim, kernel_size=1, bias=bias)
        self.fuse_out1 = nn.Conv2d(dim*2, dim, kernel_size=1, bias=bias)

        self.log_sigmoid = nn.LogSigmoid()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, n1, n2):
        x_ = self.conv(x)
        n1_ = self.conv_n1(n1)
        n2_ = self.conv_n2(n2)
        kl_n1 = F.kl_div(input=self.log_sigmoid(n1_), target=self.log_sigmoid(x_), log_target=True)
        kl_n2 = F.kl_div(input=self.log_sigmoid(n2_), target=self.log_sigmoid(x_), log_target=True)
        #g = self.sigmoid(x_)
        g1 = self.sigmoid(kl_n1)
        g2 = self.sigmoid(kl_n2)
        #x = (1 + g) * x_ + (1 - g) * (g1 * n1_ + g2 * n2_)
        x = self.fuse_out1(torch.cat((x_, g1 * n1_ + g2 * n2_), dim=1))

        return x


##########################################################################
##---------- StripScanNet -----------------------
class XYScanNetP(nn.Module):
    def __init__(self,
                 inp_channels=3,
                 out_channels=3,
                 dim=144,  # 48, 72, 96, 120, 144
                 num_blocks=[3, 3, 6],
                 vssm_expansion_factor=1,  # 1 or 2
                 ffn_expansion_factor=1,  # 1 or 3
                 bias=False,
                 LayerNorm_type='WithBias',  ## Other option 'BiasFree'
                 ):

        super(XYScanNetP, self).__init__()

        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)

        self.encoder_level1 = nn.Sequential(*[Strip_VSSB(dim=dim, vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor=ffn_expansion_factor,
                                                         bias=bias, ssm=False, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])

        self.down1_2 = Downsample(dim)  ## From Level 1 to Level 2
        self.encoder_level2 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**1), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor=ffn_expansion_factor,
                                                         bias=bias, ssm=False, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])

        self.down2_3 = Downsample(int(dim*2**1))  ## From Level 2 to Level 3
        self.encoder_level3 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**2), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor=ffn_expansion_factor,
                                                         bias=bias, ssm=False, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])

        self.decoder_level3 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**2), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor=ffn_expansion_factor,
                                                         bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])

        self.up3_2 = Upsample(int(dim*2**2))  ## From Level 3 to Level 2
        self.clff_level2 = CLFF(int(dim*2**1), dim_n1=int(dim*2**0), dim_n2=(dim*2**2), bias=bias)
        self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
        self.decoder_level2 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**1), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor=ffn_expansion_factor,
                                                         bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])

        self.up2_1 = Upsample(int(dim*2**1))  ## From Level 2 to Level 1
        self.clff_level1 = CLFF(int(dim*2**0), dim_n1=int(dim*2**1), dim_n2=(dim*2**2), bias=bias)
        self.reduce_chan_level1 = nn.Conv2d(int(dim*2**1), int(dim*2**0), kernel_size=1, bias=bias)
        self.decoder_level1 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**0), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor=ffn_expansion_factor,
                                                         bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])

        # self.refinement = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**0), expansion_factor=expansion_factor, bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])

        self.output = nn.Conv2d(int(dim*2**0), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, inp_img):

        # Encoder
        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)
        out_enc_level1_2 = F.interpolate(out_enc_level1, scale_factor=0.5)  # dim*2, lvl1 down-scaled to lvl2

        inp_enc_level2 = self.down1_2(out_enc_level1)
        out_enc_level2 = self.encoder_level2(inp_enc_level2)
        out_enc_level2_1 = F.interpolate(out_enc_level2, scale_factor=2)  # dim*2, lvl2 up-scaled to lvl1

        inp_enc_level3 = self.down2_3(out_enc_level2)
        out_enc_level3 = self.encoder_level3(inp_enc_level3)
        out_enc_level3_2 = F.interpolate(out_enc_level3, scale_factor=2)  # dim*2**2, lvl3 up-scaled to lvl2 (lvl3->lvl2)
        out_enc_level3_1 = F.interpolate(out_enc_level3_2, scale_factor=2)  # dim*2**2, lvl3 up-scaled to lvl1 (lvl3->lvl2->lvl1)

        out_enc_level1 = self.clff_level1(out_enc_level1, out_enc_level2_1, out_enc_level3_1)
        out_enc_level2 = self.clff_level2(out_enc_level2, out_enc_level1_2, out_enc_level3_2)

        # Decoder
        out_dec_level3_decomp1 = self.decoder_level3(out_enc_level3)

        inp_dec_level2_decomp1 = self.up3_2(out_dec_level3_decomp1)
        inp_dec_level2_decomp1 = self.reduce_chan_level2(torch.cat((inp_dec_level2_decomp1, out_enc_level2), dim=1))
        out_dec_level2_decomp1 = self.decoder_level2(inp_dec_level2_decomp1)

        inp_dec_level1_decomp1 = self.up2_1(out_dec_level2_decomp1)
        inp_dec_level1_decomp1 = self.reduce_chan_level1(torch.cat((inp_dec_level1_decomp1, out_enc_level1), dim=1))
        out_dec_level1_decomp1 = self.decoder_level1(inp_dec_level1_decomp1)

        out_dec_level1_decomp1 = self.output(out_dec_level1_decomp1)

        out_dec_level1 = out_dec_level1_decomp1 + inp_img

        return out_dec_level1, out_dec_level1_decomp1, None


def count_parameters(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total:,}")
    print(f"Trainable parameters: {trainable:,}")


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = XYScanNetP().to(device)

    print("Model architecture:\n")
    print(model)

    count_parameters(model)

    # Optionally test with a dummy input
    dummy_input = torch.randn(1, 3, 256, 256).to(device)
    output, _, _ = model(dummy_input)
    print(f"Output shape: {output.shape}")


if __name__ == "__main__":
    main()
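Usage note (not part of the commit): the encoder above downsamples twice and the strip scans reshape features per row/column, so inference is simplest when the input height and width are multiples of 8. A minimal padded-inference sketch, assuming the class is importable as models.XYScanNetP.XYScanNetP and that trained weights are loaded elsewhere; deblur_padded and its arguments are hypothetical names.

import torch
import torch.nn.functional as F
from models.XYScanNetP import XYScanNetP

def deblur_padded(model, blur, multiple=8):
    # pad the blurry image so H and W are multiples of `multiple`, then crop back
    _, _, h, w = blur.shape
    ph = (multiple - h % multiple) % multiple
    pw = (multiple - w % multiple) % multiple
    padded = F.pad(blur, (0, pw, 0, ph), mode='reflect')
    with torch.no_grad():
        restored, _, _ = model(padded)   # forward returns (output, residual, None)
    return restored[..., :h, :w]

model = XYScanNetP().cuda().eval()
out = deblur_padded(model, torch.rand(1, 3, 720, 1280).cuda())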
models/__init__.py
ADDED
File without changes
models/__pycache__/XYScanNet.cpython-38.pyc
ADDED
Binary file (21.2 kB)
models/__pycache__/XYScanNetP.cpython-38.pyc
ADDED
Binary file (21.2 kB)
models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (140 Bytes)
models/__pycache__/networks.cpython-38.pyc
ADDED
Binary file (691 Bytes)
models/losses.py
ADDED
@@ -0,0 +1,233 @@
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

class Vgg19(torch.nn.Module):
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        vgg_pretrained_features = vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()

        for x in range(12):
            self.slice1.add_module(str(x), vgg_pretrained_features[x].eval())

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h_relu1 = self.slice1(X)
        return h_relu1

class ContrastLoss(nn.Module):
    def __init__(self, ablation=False):

        super(ContrastLoss, self).__init__()
        self.vgg = Vgg19().cuda()
        self.l1 = nn.L1Loss()
        self.ab = ablation
        self.down_sample_4 = nn.Upsample(scale_factor=1 / 4, mode='bilinear')

    def forward(self, restore, sharp, blur):
        B, C, H, W = restore.size()
        restore_vgg, sharp_vgg, blur_vgg = self.vgg(restore), self.vgg(sharp), self.vgg(blur)

        # filter out sharp regions
        threshold = 0.01
        mask = torch.mean(torch.abs(sharp-blur), dim=1).view(B, 1, H, W)
        mask[mask <= threshold] = 0
        mask[mask > threshold] = 1
        mask = self.down_sample_4(mask)
        d_ap = torch.mean(torch.abs((restore_vgg - sharp_vgg.detach())), dim=1).view(B, 1, H//4, W//4)
        d_an = torch.mean(torch.abs((restore_vgg - blur_vgg.detach())), dim=1).view(B, 1, H//4, W//4)
        mask_size = torch.sum(mask)
        contrastive = torch.sum((d_ap / (d_an + 1e-7)) * mask) / mask_size

        return contrastive


class ContrastLoss_Ori(nn.Module):
    def __init__(self, ablation=False):
        super(ContrastLoss_Ori, self).__init__()
        self.vgg = Vgg19().cuda()
        self.l1 = nn.L1Loss()
        self.ab = ablation

    def forward(self, restore, sharp, blur):

        restore_vgg, sharp_vgg, blur_vgg = self.vgg(restore), self.vgg(sharp), self.vgg(blur)
        d_ap = self.l1(restore_vgg, sharp_vgg.detach())
        d_an = self.l1(restore_vgg, blur_vgg.detach())
        contrastive_loss = d_ap / (d_an + 1e-7)

        return contrastive_loss


class CharbonnierLoss(nn.Module):
    """Charbonnier Loss (L1)"""

    def __init__(self, eps=1e-3):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps

    def forward(self, x, y):
        diff = x - y
        # loss = torch.sum(torch.sqrt(diff * diff + self.eps))
        loss = torch.mean(torch.sqrt((diff * diff) + (self.eps * self.eps)))
        return loss


class EdgeLoss(nn.Module):
    def __init__(self):
        super(EdgeLoss, self).__init__()
        k = torch.Tensor([[.05, .25, .4, .25, .05]])
        self.kernel = torch.matmul(k.t(), k).unsqueeze(0).repeat(3, 1, 1, 1)
        if torch.cuda.is_available():
            self.kernel = self.kernel.cuda()
        self.loss = CharbonnierLoss()

    def conv_gauss(self, img):
        n_channels, _, kw, kh = self.kernel.shape
        img = F.pad(img, (kw // 2, kh // 2, kw // 2, kh // 2), mode='replicate')
        return F.conv2d(img, self.kernel, groups=n_channels)

    def laplacian_kernel(self, current):
        filtered = self.conv_gauss(current)  # filter
        down = filtered[:, :, ::2, ::2]  # downsample
        new_filter = torch.zeros_like(filtered)
        new_filter[:, :, ::2, ::2] = down * 4  # upsample
        filtered = self.conv_gauss(new_filter)  # filter
        diff = current - filtered
        return diff

    def forward(self, x, y):
        # x = torch.clamp(x + 0.5, min = 0,max = 1)
        # y = torch.clamp(y + 0.5, min = 0,max = 1)
        loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
        return loss


class Stripformer_Loss(nn.Module):

    def __init__(self, ):
        super(Stripformer_Loss, self).__init__()

        self.char = CharbonnierLoss()
        self.edge = EdgeLoss()
        self.contrastive = ContrastLoss()

    def forward(self, restore, sharp, blur):
        char = self.char(restore, sharp)
        edge = 0.05 * self.edge(restore, sharp)
        contrastive = 0.0005 * self.contrastive(restore, sharp, blur)
        loss = char + edge + contrastive
        return loss


def get_loss(model):
    if model['content_loss'] == 'Stripformer_Loss':
        content_loss = Stripformer_Loss()
    elif model['content_loss'] == 'CharbonnierLoss':
        content_loss = CharbonnierLoss()
    else:
        raise ValueError("ContentLoss [%s] not recognized." % model['content_loss'])
    return content_loss

from typing import Union, List, Dict, Any, cast

import torch
import torch.nn as nn

class VGG(nn.Module):
    def __init__(
        self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5
    ) -> None:
        super().__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, 0, 0.01)
                    nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
    layers: List[nn.Module] = []
    in_channels = 3
    for v in cfg:
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            v = cast(int, v)
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)


cfgs: Dict[str, List[Union[str, int]]] = {
    "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
    "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}

def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG:
    if pretrained:
        kwargs["init_weights"] = False
    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
    if pretrained:
        state_dict = torch.load("/home/hanzhou1996/low-level/StripMamba/models/vgg19-dcbb9e9d.pth")  # change the path to vgg19.pth
        model.load_state_dict(state_dict)
    return model


def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 19-layer model (configuration "E")
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
    The required minimum input size of the model is 32x32.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg("vgg19", "E", False, pretrained, progress, **kwargs)
"""
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    #model = VGG(make_layers(cfgs["E"], batch_norm=False)).to(device)
    #model.load_state_dict(torch.load("models/vgg19-dcbb9e9d.pth"))
    model = vgg19().to(device)
    print(model.features)
    BATCH_SIZE = 3
    x = torch.randn(3, 3, 224, 224).to(device)
    assert model(x).shape == torch.Size([BATCH_SIZE, 1000])
    print(model(x).shape)
"""
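Usage note (not part of the commit): get_loss expects a config dict whose 'content_loss' entry names one of the losses above; Stripformer_Loss combines Charbonnier, edge, and contrastive terms. A hedged sketch of one training step, where model, optimizer, blur, and sharp are placeholder objects and must live on the GPU, since ContrastLoss builds Vgg19().cuda() internally.

criterion = get_loss({'content_loss': 'Stripformer_Loss'})
restored, _, _ = model(blur)             # XYScanNet/XYScanNetP return a 3-tuple
loss = criterion(restored, sharp, blur)  # restored vs. sharp (positive) and blur (negative)
optimizer.zero_grad()
loss.backward()
optimizer.step()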
models/models.py
ADDED
@@ -0,0 +1,36 @@
import numpy as np
import torch.nn as nn
#from skimage.measure import compare_ssim as SSIM
from skimage.metrics import structural_similarity as SSIM

from util.metrics import PSNR


class DeblurModel(nn.Module):
    def __init__(self):
        super(DeblurModel, self).__init__()

    def get_input(self, data):
        img = data['a']
        inputs = img
        targets = data['b']
        inputs, targets = inputs.cuda(), targets.cuda()
        return inputs, targets

    def tensor2im(self, image_tensor, imtype=np.uint8):
        image_numpy = image_tensor[0].cpu().float().numpy()
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 0.5) * 255.0
        return image_numpy

    def get_images_and_metrics(self, inp, output, target) -> (float, float, np.ndarray):
        inp = self.tensor2im(inp)
        fake = self.tensor2im(output.data)
        real = self.tensor2im(target.data)
        psnr = PSNR(fake, real)
        ssim = SSIM(fake.astype('uint8'), real.astype('uint8'), channel_axis=2)
        vis_img = np.hstack((inp, fake, real))
        return psnr, ssim, vis_img


def get_model(model_config):
    return DeblurModel()
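Usage note (not part of the commit): DeblurModel only moves data to the GPU and converts tensors back to images for metric computation; the generator itself comes from models.networks. A small sketch where blur_batch, sharp_batch, and restored are assumed placeholder tensors in the [-0.5, 0.5] range used by tensor2im.

dm = get_model(None)  # the config argument is ignored by get_model
inputs, targets = dm.get_input({'a': blur_batch, 'b': sharp_batch})
psnr, ssim, vis = dm.get_images_and_metrics(inputs, restored, targets)
print(f"PSNR: {psnr:.2f} dB, SSIM: {ssim:.4f}")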
models/networks.py
ADDED
@@ -0,0 +1,16 @@
import torch.nn as nn
from models.XYScanNet import XYScanNet
from models.XYScanNetP import XYScanNetP

def get_generator(model_config):
    generator_name = model_config['g_name']
    if generator_name == 'XYScanNet':
        model_g = XYScanNet()
    elif generator_name == 'XYScanNetP':
        model_g = XYScanNetP()
    else:
        raise ValueError("Generator Network [%s] not recognized." % generator_name)
    return nn.DataParallel(model_g)

def get_nets(model_config):
    return get_generator(model_config)
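Usage note (not part of the commit): get_nets wraps the chosen generator in nn.DataParallel, so checkpoints saved from it typically carry 'module.'-prefixed keys. A hedged loading sketch for the stage-2 weights shipped in this Space; the 'model' checkpoint key is an assumption.

import torch
net = get_nets({'g_name': 'XYScanNetP'}).cuda()
state = torch.load('results/xyscannetp_gopro/models/best_XYScanNet_stage2.pth', map_location='cuda')
net.load_state_dict(state.get('model', state))  # fall back to a raw state_dict
net.eval()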
models/sota/FFTformer.py
ADDED
@@ -0,0 +1,324 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numbers
from einops import rearrange


def to_3d(x):
    return rearrange(x, 'b c h w -> b (h w) c')


def to_4d(x, h, w):
    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)


class BiasFree_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma + 1e-5) * self.weight


class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias


class LayerNorm(nn.Module):
    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        if LayerNorm_type == 'BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)


class DFFN(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias):

        super(DFFN, self).__init__()

        hidden_features = int(dim * ffn_expansion_factor)

        self.patch_size = 8

        self.dim = dim
        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)

        self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1,
                                groups=hidden_features * 2, bias=bias)

        self.fft = nn.Parameter(torch.ones((hidden_features * 2, 1, 1, self.patch_size, self.patch_size // 2 + 1)))
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x_patch = rearrange(x, 'b c (h patch1) (w patch2) -> b c h w patch1 patch2', patch1=self.patch_size,
                            patch2=self.patch_size)
        x_patch_fft = torch.fft.rfft2(x_patch.float())
        x_patch_fft = x_patch_fft * self.fft
        x_patch = torch.fft.irfft2(x_patch_fft, s=(self.patch_size, self.patch_size))
        x = rearrange(x_patch, 'b c h w patch1 patch2 -> b c (h patch1) (w patch2)', patch1=self.patch_size,
                      patch2=self.patch_size)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)

        x = F.gelu(x1) * x2
        x = self.project_out(x)
        return x


class FSAS(nn.Module):
    def __init__(self, dim, bias):
        super(FSAS, self).__init__()

        self.to_hidden = nn.Conv2d(dim, dim * 6, kernel_size=1, bias=bias)
        self.to_hidden_dw = nn.Conv2d(dim * 6, dim * 6, kernel_size=3, stride=1, padding=1, groups=dim * 6, bias=bias)

        self.project_out = nn.Conv2d(dim * 2, dim, kernel_size=1, bias=bias)

        self.norm = LayerNorm(dim * 2, LayerNorm_type='WithBias')

        self.patch_size = 8

    def forward(self, x):
        hidden = self.to_hidden(x)

        q, k, v = self.to_hidden_dw(hidden).chunk(3, dim=1)

        q_patch = rearrange(q, 'b c (h patch1) (w patch2) -> b c h w patch1 patch2', patch1=self.patch_size,
                            patch2=self.patch_size)
        k_patch = rearrange(k, 'b c (h patch1) (w patch2) -> b c h w patch1 patch2', patch1=self.patch_size,
                            patch2=self.patch_size)
        q_fft = torch.fft.rfft2(q_patch.float())
        k_fft = torch.fft.rfft2(k_patch.float())

        out = q_fft * k_fft
        out = torch.fft.irfft2(out, s=(self.patch_size, self.patch_size))
        out = rearrange(out, 'b c h w patch1 patch2 -> b c (h patch1) (w patch2)', patch1=self.patch_size,
                        patch2=self.patch_size)

        out = self.norm(out)

        output = v * out
        output = self.project_out(output)

        return output


##########################################################################
class TransformerBlock(nn.Module):
    def __init__(self, dim, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', att=False):
        super(TransformerBlock, self).__init__()

        self.att = att
        if self.att:
            self.norm1 = LayerNorm(dim, LayerNorm_type)
            self.attn = FSAS(dim, bias)

        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = DFFN(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        if self.att:
            x = x + self.attn(self.norm1(x))

        x = x + self.ffn(self.norm2(x))

        return x


class Fuse(nn.Module):
    def __init__(self, n_feat):
        super(Fuse, self).__init__()
        self.n_feat = n_feat
        self.att_channel = TransformerBlock(dim=n_feat * 2)

        self.conv = nn.Conv2d(n_feat * 2, n_feat * 2, 1, 1, 0)
        self.conv2 = nn.Conv2d(n_feat * 2, n_feat * 2, 1, 1, 0)

    def forward(self, enc, dnc):
        x = self.conv(torch.cat((enc, dnc), dim=1))
        x = self.att_channel(x)
        x = self.conv2(x)
        e, d = torch.split(x, [self.n_feat, self.n_feat], dim=1)
        output = e + d

        return output


##########################################################################
## Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super(OverlapPatchEmbed, self).__init__()

        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, x):
        x = self.proj(x)

        return x


##########################################################################
## Resizing modules
class Downsample(nn.Module):
    def __init__(self, n_feat):
        super(Downsample, self).__init__()

        self.body = nn.Sequential(nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
                                  nn.Conv2d(n_feat, n_feat * 2, 3, stride=1, padding=1, bias=False))

    def forward(self, x):
        return self.body(x)


class Upsample(nn.Module):
    def __init__(self, n_feat):
        super(Upsample, self).__init__()

        self.body = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                                  nn.Conv2d(n_feat, n_feat // 2, 3, stride=1, padding=1, bias=False))

    def forward(self, x):
        return self.body(x)


##########################################################################
##---------- FFTformer -----------------------
class fftformer(nn.Module):
    def __init__(self,
                 inp_channels=3,
                 out_channels=3,
                 dim=8,
                 num_blocks=[6, 6, 12, 8],
                 num_refinement_blocks=4,
                 ffn_expansion_factor=3,
                 bias=False,
                 ):
        super(fftformer, self).__init__()

        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)

        self.encoder_level1 = nn.Sequential(*[
            TransformerBlock(dim=dim, ffn_expansion_factor=ffn_expansion_factor, bias=bias) for i in
            range(num_blocks[0])])

        self.down1_2 = Downsample(dim)
        self.encoder_level2 = nn.Sequential(*[
            TransformerBlock(dim=int(dim * 2 ** 1), ffn_expansion_factor=ffn_expansion_factor,
                             bias=bias) for i in range(num_blocks[1])])

        self.down2_3 = Downsample(int(dim * 2 ** 1))
        self.encoder_level3 = nn.Sequential(*[
            TransformerBlock(dim=int(dim * 2 ** 2), ffn_expansion_factor=ffn_expansion_factor,
                             bias=bias) for i in range(num_blocks[2])])

        self.decoder_level3 = nn.Sequential(*[
            TransformerBlock(dim=int(dim * 2 ** 2), ffn_expansion_factor=ffn_expansion_factor,
                             bias=bias, att=True) for i in range(num_blocks[2])])

        self.up3_2 = Upsample(int(dim * 2 ** 2))
        self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias)
        self.decoder_level2 = nn.Sequential(*[
            TransformerBlock(dim=int(dim * 2 ** 1), ffn_expansion_factor=ffn_expansion_factor,
                             bias=bias, att=True) for i in range(num_blocks[1])])

        self.up2_1 = Upsample(int(dim * 2 ** 1))

        self.decoder_level1 = nn.Sequential(*[
            TransformerBlock(dim=int(dim), ffn_expansion_factor=ffn_expansion_factor,
                             bias=bias, att=True) for i in range(num_blocks[0])])

        self.refinement = nn.Sequential(*[
            TransformerBlock(dim=int(dim), ffn_expansion_factor=ffn_expansion_factor,
                             bias=bias, att=True) for i in range(num_refinement_blocks)])

        self.fuse2 = Fuse(dim * 2)
        self.fuse1 = Fuse(dim)
        self.output = nn.Conv2d(int(dim), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, inp_img):
        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)

        inp_enc_level2 = self.down1_2(out_enc_level1)
        out_enc_level2 = self.encoder_level2(inp_enc_level2)

        inp_enc_level3 = self.down2_3(out_enc_level2)
        out_enc_level3 = self.encoder_level3(inp_enc_level3)

        out_dec_level3 = self.decoder_level3(out_enc_level3)

        inp_dec_level2 = self.up3_2(out_dec_level3)

        inp_dec_level2 = self.fuse2(inp_dec_level2, out_enc_level2)

        out_dec_level2 = self.decoder_level2(inp_dec_level2)

        inp_dec_level1 = self.up2_1(out_dec_level2)

        inp_dec_level1 = self.fuse1(inp_dec_level1, out_enc_level1)
        out_dec_level1 = self.decoder_level1(inp_dec_level1)

        out_dec_level1 = self.refinement(out_dec_level1)

        out_dec_level1 = self.output(out_dec_level1) + inp_img

        return out_dec_level1

#"""
import time
start_time = time.time()
inp = torch.randn(1, 3, 512, 512).cuda()#.to(dtype=torch.float16)
model = fftformer().cuda()#.to(dtype=torch.float16)
out = model(inp)
print(out.shape)
print("--- %s seconds ---" % (time.time() - start_time))
pytorch_total_params = sum(p.numel() for p in model.parameters())
print("--- {num} parameters ---".format(num = pytorch_total_params))
pytorch_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("--- {num} trainable parameters ---".format(num = pytorch_trainable_params))
gpu_memmem_usage_bytes = torch.cuda.max_memory_allocated()
print(gpu_memmem_usage_bytes / 1024 / 1024 / 1024) # 64: 1.32 128: 4.94 256: 19.12; 512: OOM
#"""
"""
import torch
from ptflops import get_model_complexity_info

with torch.cuda.device(0):
    net = model
    macs, params = get_model_complexity_info(net, (3, 256, 256), as_strings=True,
                                             print_per_layer_stat=True, verbose=True)
    print('{:<30} {:<8}'.format('Computational complexity: ', macs)) # 31.97 GMac
    print('{:<30} {:<8}'.format('Number of parameters: ', params)) # 8.37 M
"""
models/sota/Restormer.py
ADDED
@@ -0,0 +1,340 @@
| 1 |
+
## Restormer: Efficient Transformer for High-Resolution Image Restoration
|
| 2 |
+
## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
|
| 3 |
+
## https://arxiv.org/abs/2111.09881
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from pdb import set_trace as stx
|
| 10 |
+
import numbers
|
| 11 |
+
|
| 12 |
+
from einops import rearrange
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
##########################################################################
|
| 17 |
+
## Layer Norm
|
| 18 |
+
|
| 19 |
+
def to_3d(x):
|
| 20 |
+
return rearrange(x, 'b c h w -> b (h w) c')
|
| 21 |
+
|
| 22 |
+
def to_4d(x,h,w):
|
| 23 |
+
return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w)
|
| 24 |
+
|
| 25 |
+
class BiasFree_LayerNorm(nn.Module):
|
| 26 |
+
def __init__(self, normalized_shape):
|
| 27 |
+
super(BiasFree_LayerNorm, self).__init__()
|
| 28 |
+
if isinstance(normalized_shape, numbers.Integral):
|
| 29 |
+
normalized_shape = (normalized_shape,)
|
| 30 |
+
normalized_shape = torch.Size(normalized_shape)
|
| 31 |
+
|
| 32 |
+
assert len(normalized_shape) == 1
|
| 33 |
+
|
| 34 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
| 35 |
+
self.normalized_shape = normalized_shape
|
| 36 |
+
|
| 37 |
+
def forward(self, x):
|
| 38 |
+
sigma = x.var(-1, keepdim=True, unbiased=False)
|
| 39 |
+
return x / torch.sqrt(sigma+1e-5) * self.weight
|
| 40 |
+
|
| 41 |
+
class WithBias_LayerNorm(nn.Module):
|
| 42 |
+
def __init__(self, normalized_shape):
|
| 43 |
+
super(WithBias_LayerNorm, self).__init__()
|
| 44 |
+
if isinstance(normalized_shape, numbers.Integral):
|
| 45 |
+
normalized_shape = (normalized_shape,)
|
| 46 |
+
normalized_shape = torch.Size(normalized_shape)
|
| 47 |
+
|
| 48 |
+
assert len(normalized_shape) == 1
|
| 49 |
+
|
| 50 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
| 51 |
+
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
| 52 |
+
self.normalized_shape = normalized_shape
|
| 53 |
+
|
| 54 |
+
def forward(self, x):
|
| 55 |
+
mu = x.mean(-1, keepdim=True)
|
| 56 |
+
sigma = x.var(-1, keepdim=True, unbiased=False)
|
| 57 |
+
return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class LayerNorm(nn.Module):
|
| 61 |
+
def __init__(self, dim, LayerNorm_type):
|
| 62 |
+
super(LayerNorm, self).__init__()
|
| 63 |
+
if LayerNorm_type =='BiasFree':
|
| 64 |
+
self.body = BiasFree_LayerNorm(dim)
|
| 65 |
+
else:
|
| 66 |
+
self.body = WithBias_LayerNorm(dim)
|
| 67 |
+
|
| 68 |
+
def forward(self, x):
|
| 69 |
+
h, w = x.shape[-2:]
|
| 70 |
+
return to_4d(self.body(to_3d(x)), h, w)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
##########################################################################
|
| 75 |
+
## Gated-Dconv Feed-Forward Network (GDFN)
|
| 76 |
+
class FeedForward(nn.Module):
|
| 77 |
+
def __init__(self, dim, ffn_expansion_factor, bias):
|
| 78 |
+
super(FeedForward, self).__init__()
|
| 79 |
+
|
| 80 |
+
hidden_features = int(dim*ffn_expansion_factor)
|
| 81 |
+
|
| 82 |
+
self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
|
| 83 |
+
|
| 84 |
+
self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)
|
| 85 |
+
|
| 86 |
+
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
|
| 87 |
+
|
| 88 |
+
def forward(self, x):
|
| 89 |
+
x = self.project_in(x)
|
| 90 |
+
x1, x2 = self.dwconv(x).chunk(2, dim=1)
|
| 91 |
+
x = F.gelu(x1) * x2
|
| 92 |
+
x = self.project_out(x)
|
| 93 |
+
return x
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
##########################################################################
|
| 98 |
+
## Multi-DConv Head Transposed Self-Attention (MDTA)
|
| 99 |
+
class Attention(nn.Module):
|
| 100 |
+
def __init__(self, dim, num_heads, bias):
|
| 101 |
+
super(Attention, self).__init__()
|
| 102 |
+
self.num_heads = num_heads
|
| 103 |
+
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
|
| 104 |
+
|
| 105 |
+
self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
|
| 106 |
+
self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
|
| 107 |
+
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def forward(self, x):
|
| 112 |
+
b,c,h,w = x.shape
|
| 113 |
+
|
| 114 |
+
qkv = self.qkv_dwconv(self.qkv(x))
|
| 115 |
+
q,k,v = qkv.chunk(3, dim=1)
|
| 116 |
+
|
| 117 |
+
q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
|
| 118 |
+
k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
|
| 119 |
+
v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
|
| 120 |
+
|
| 121 |
+
q = torch.nn.functional.normalize(q, dim=-1)
|
| 122 |
+
k = torch.nn.functional.normalize(k, dim=-1)
|
| 123 |
+
|
| 124 |
+
attn = (q @ k.transpose(-2, -1)) * self.temperature
|
| 125 |
+
attn = attn.softmax(dim=-1)
|
| 126 |
+
|
| 127 |
+
out = (attn @ v)
|
| 128 |
+
|
| 129 |
+
out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
|
| 130 |
+
|
| 131 |
+
out = self.project_out(out)
|
| 132 |
+
return out
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
##########################################################################
|
| 137 |
+
class TransformerBlock(nn.Module):
|
| 138 |
+
def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
|
| 139 |
+
super(TransformerBlock, self).__init__()
|
| 140 |
+
|
| 141 |
+
self.norm1 = LayerNorm(dim, LayerNorm_type)
|
| 142 |
+
self.attn = Attention(dim, num_heads, bias)
|
| 143 |
+
self.norm2 = LayerNorm(dim, LayerNorm_type)
|
| 144 |
+
self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
|
| 145 |
+
|
| 146 |
+
def forward(self, x):
|
| 147 |
+
x = x + self.attn(self.norm1(x))
|
| 148 |
+
x = x + self.ffn(self.norm2(x))
|
| 149 |
+
|
| 150 |
+
return x
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
##########################################################################
|
| 155 |
+
## Overlapped image patch embedding with 3x3 Conv
|
| 156 |
+
class OverlapPatchEmbed(nn.Module):
|
| 157 |
+
def __init__(self, in_c=3, embed_dim=48, bias=False):
|
| 158 |
+
super(OverlapPatchEmbed, self).__init__()
|
| 159 |
+
|
| 160 |
+
self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)
|
| 161 |
+
|
| 162 |
+
def forward(self, x):
|
| 163 |
+
x = self.proj(x)
|
| 164 |
+
|
| 165 |
+
return x
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
##########################################################################
|
| 170 |
+
## Resizing modules
|
| 171 |
+
class Downsample(nn.Module):
|
| 172 |
+
def __init__(self, n_feat):
|
| 173 |
+
super(Downsample, self).__init__()
|
| 174 |
+
|
| 175 |
+
self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False),
|
| 176 |
+
nn.PixelUnshuffle(2))
|
| 177 |
+
|
| 178 |
+
def forward(self, x):
|
| 179 |
+
return self.body(x)
|
| 180 |
+
|
| 181 |
+
class Upsample(nn.Module):
|
| 182 |
+
def __init__(self, n_feat):
|
| 183 |
+
super(Upsample, self).__init__()
|
| 184 |
+
|
| 185 |
+
self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False),
|
| 186 |
+
nn.PixelShuffle(2))
|
| 187 |
+
|
| 188 |
+
def forward(self, x):
|
| 189 |
+
return self.body(x)
|
| 190 |
+
|
| 191 |
+
##########################################################################
|
| 192 |
+
class Strip_VSSB(nn.Module):
|
| 193 |
+
def __init__(self, dim, head_num):
|
| 194 |
+
super(Strip_VSSB, self).__init__()
|
| 195 |
+
|
| 196 |
+
self.intra = TransformerBlock(dim=32, num_heads=head_num, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias')
|
| 197 |
+
self.inter = TransformerBlock(dim=32, num_heads=head_num, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias')
|
| 198 |
+
|
| 199 |
+
def forward(self, x):
|
| 200 |
+
x = self.intra(x)
|
| 201 |
+
x = self.inter(x)
|
| 202 |
+
|
| 203 |
+
return x
|
| 204 |
+
|
| 205 |
+
##########################################################################
|
| 206 |
+
##---------- Restormer -----------------------
|
| 207 |
+
class Restormer(nn.Module):
|
| 208 |
+
def __init__(self,
|
| 209 |
+
inp_channels=3,
|
| 210 |
+
out_channels=3,
|
| 211 |
+
dim = 12,
|
| 212 |
+
num_blocks = [4,6,6,8],
|
| 213 |
+
num_refinement_blocks = 4,
|
| 214 |
+
heads = [1,2,4,8],
|
| 215 |
+
ffn_expansion_factor = 2.66,
|
| 216 |
+
bias = False,
|
| 217 |
+
LayerNorm_type = 'WithBias', ## Other option 'BiasFree'
|
| 218 |
+
dual_pixel_task = False ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
|
| 219 |
+
):
|
| 220 |
+
|
| 221 |
+
super(Restormer, self).__init__()
|
| 222 |
+
|
| 223 |
+
self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
|
| 224 |
+
|
| 225 |
+
self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
|
| 226 |
+
|
| 227 |
+
self.down1_2 = Downsample(dim) ## From Level 1 to Level 2
|
| 228 |
+
self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
|
| 229 |
+
|
| 230 |
+
self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3
|
| 231 |
+
self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
|
| 232 |
+
|
| 233 |
+
self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4
|
| 234 |
+
self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])])
|
| 235 |
+
|
| 236 |
+
self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3
|
| 237 |
+
self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias)
|
| 238 |
+
self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2
|
| 242 |
+
self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
|
| 243 |
+
self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
|
| 244 |
+
|
| 245 |
+
self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels)
|
| 246 |
+
|
| 247 |
+
self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
|
| 248 |
+
|
| 249 |
+
        self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])

        #### For Dual-Pixel Defocus Deblurring Task ####
        self.dual_pixel_task = dual_pixel_task
        if self.dual_pixel_task:
            self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias)
        ###########################

        self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, inp_img):

        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)

        inp_enc_level2 = self.down1_2(out_enc_level1)
        out_enc_level2 = self.encoder_level2(inp_enc_level2)

        inp_enc_level3 = self.down2_3(out_enc_level2)
        out_enc_level3 = self.encoder_level3(inp_enc_level3)

        inp_enc_level4 = self.down3_4(out_enc_level3)
        latent = self.latent(inp_enc_level4)

        inp_dec_level3 = self.up4_3(latent)
        inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
        inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
        out_dec_level3 = self.decoder_level3(inp_dec_level3)

        inp_dec_level2 = self.up3_2(out_dec_level3)
        inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
        inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
        out_dec_level2 = self.decoder_level2(inp_dec_level2)

        inp_dec_level1 = self.up2_1(out_dec_level2)
        inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
        out_dec_level1 = self.decoder_level1(inp_dec_level1)

        out_dec_level1 = self.refinement(out_dec_level1)

        #### For Dual-Pixel Defocus Deblurring Task ####
        if self.dual_pixel_task:
            out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1)
            out_dec_level1 = self.output(out_dec_level1)
        ###########################
        else:
            out_dec_level1 = self.output(out_dec_level1) + inp_img

        return out_dec_level1

#"""
import time
start_time = time.time()
inp = torch.randn(1, 3, 256, 256).cuda()#.to(dtype=torch.float16)
model = Restormer().cuda()#.to(dtype=torch.float16)
out = model(inp)
print(out.shape)
print("--- %s seconds ---" % (time.time() - start_time))
pytorch_total_params = sum(p.numel() for p in model.parameters())
print("--- {num} parameters ---".format(num = pytorch_total_params))
pytorch_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("--- {num} trainable parameters ---".format(num = pytorch_trainable_params))
gpu_memmem_usage_bytes = torch.cuda.max_memory_allocated()
print(gpu_memmem_usage_bytes / 1024 / 1024 / 1024)  # 64: 0.97 128: 3.04 256: 11.93; 512: OOM
#"""
"""
import torch
from ptflops import get_model_complexity_info

with torch.cuda.device(0):
    net = model
    macs, params = get_model_complexity_info(net, (3, 256, 256), as_strings=True,
                                             print_per_layer_stat=True, verbose=True)
    print('{:<30} {:<8}'.format('Computational complexity: ', macs)) # 31.97 GMac
    print('{:<30} {:<8}'.format('Number of parameters: ', params)) # 8.37 M
"""
"""
import time
start_time = time.time()
inp = torch.randn(1, 32, 64, 64).cuda()#.to(dtype=torch.float16)
model = Strip_VSSB(dim=32, head_num=4).cuda()#.to(dtype=torch.float16)
out = model(inp)
print(out.shape)
print("--- %s seconds ---" % (time.time() - start_time))
pytorch_total_params = sum(p.numel() for p in model.parameters())
print("--- {num} parameters ---".format(num = pytorch_total_params))
pytorch_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("--- {num} trainable parameters ---".format(num = pytorch_trainable_params))
gpu_memmem_usage_bytes = torch.cuda.max_memory_allocated()
print(gpu_memmem_usage_bytes / 1024 / 1024 / 1024)  # 64: 0.16; 128: 0.22; 192: 0.37; 256: 0.56; 512: 2.10;
"""
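
Note on the commented-out timing snippet above: it reads time.time() around model construction and a single forward pass without synchronizing CUDA, so kernel-launch and allocation overhead are folded into the reported seconds. The sketch below is illustrative only and not part of this repository; it assumes a CUDA device, warms the model up, and synchronizes before reading the clock.

import time
import torch

# Illustrative helper (an assumption, not part of this repo):
# average forward-pass time with proper GPU synchronization.
def benchmark(model, inp, warmup=3, runs=10):
    model.eval()
    with torch.no_grad():
        for _ in range(warmup):
            model(inp)                    # warm-up passes absorb kernel-launch / autotune cost
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(runs):
            model(inp)
        torch.cuda.synchronize()          # wait for queued kernels before stopping the clock
    return (time.time() - start) / runs

# Example usage (assumes Restormer defined above and a CUDA device):
# model = Restormer().cuda()
# inp = torch.randn(1, 3, 256, 256).cuda()
# print(benchmark(model, inp))
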
models/sota/Stripformer.py
ADDED
|
@@ -0,0 +1,429 @@
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
class Embeddings(nn.Module):
|
| 6 |
+
def __init__(self):
|
| 7 |
+
super(Embeddings, self).__init__()
|
| 8 |
+
|
| 9 |
+
self.activation = nn.LeakyReLU(0.2, True)
|
| 10 |
+
|
| 11 |
+
self.en_layer1_1 = nn.Sequential(
|
| 12 |
+
nn.Conv2d(3, 64, kernel_size=3, padding=1),
|
| 13 |
+
self.activation,
|
| 14 |
+
)
|
| 15 |
+
self.en_layer1_2 = nn.Sequential(
|
| 16 |
+
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
| 17 |
+
self.activation,
|
| 18 |
+
nn.Conv2d(64, 64, kernel_size=3, padding=1))
|
| 19 |
+
self.en_layer1_3 = nn.Sequential(
|
| 20 |
+
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
| 21 |
+
self.activation,
|
| 22 |
+
nn.Conv2d(64, 64, kernel_size=3, padding=1))
|
| 23 |
+
self.en_layer1_4 = nn.Sequential(
|
| 24 |
+
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
| 25 |
+
self.activation,
|
| 26 |
+
nn.Conv2d(64, 64, kernel_size=3, padding=1))
|
| 27 |
+
|
| 28 |
+
self.en_layer2_1 = nn.Sequential(
|
| 29 |
+
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
|
| 30 |
+
self.activation,
|
| 31 |
+
)
|
| 32 |
+
self.en_layer2_2 = nn.Sequential(
|
| 33 |
+
nn.Conv2d(128, 128, kernel_size=3, padding=1),
|
| 34 |
+
self.activation,
|
| 35 |
+
nn.Conv2d(128, 128, kernel_size=3, padding=1))
|
| 36 |
+
self.en_layer2_3 = nn.Sequential(
|
| 37 |
+
nn.Conv2d(128, 128, kernel_size=3, padding=1),
|
| 38 |
+
self.activation,
|
| 39 |
+
nn.Conv2d(128, 128, kernel_size=3, padding=1))
|
| 40 |
+
self.en_layer2_4 = nn.Sequential(
|
| 41 |
+
nn.Conv2d(128, 128, kernel_size=3, padding=1),
|
| 42 |
+
self.activation,
|
| 43 |
+
nn.Conv2d(128, 128, kernel_size=3, padding=1))
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
self.en_layer3_1 = nn.Sequential(
|
| 47 |
+
nn.Conv2d(128, 320, kernel_size=3, stride=2, padding=1),
|
| 48 |
+
self.activation,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def forward(self, x):
|
| 53 |
+
|
| 54 |
+
hx = self.en_layer1_1(x)
|
| 55 |
+
hx = self.activation(self.en_layer1_2(hx) + hx)
|
| 56 |
+
hx = self.activation(self.en_layer1_3(hx) + hx)
|
| 57 |
+
hx = self.activation(self.en_layer1_4(hx) + hx)
|
| 58 |
+
residual_1 = hx
|
| 59 |
+
hx = self.en_layer2_1(hx)
|
| 60 |
+
hx = self.activation(self.en_layer2_2(hx) + hx)
|
| 61 |
+
hx = self.activation(self.en_layer2_3(hx) + hx)
|
| 62 |
+
hx = self.activation(self.en_layer2_4(hx) + hx)
|
| 63 |
+
residual_2 = hx
|
| 64 |
+
hx = self.en_layer3_1(hx)
|
| 65 |
+
|
| 66 |
+
return hx, residual_1, residual_2
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class Embeddings_output(nn.Module):
|
| 70 |
+
def __init__(self):
|
| 71 |
+
super(Embeddings_output, self).__init__()
|
| 72 |
+
|
| 73 |
+
self.activation = nn.LeakyReLU(0.2, True)
|
| 74 |
+
|
| 75 |
+
self.de_layer3_1 = nn.Sequential(
|
| 76 |
+
nn.ConvTranspose2d(320, 192, kernel_size=4, stride=2, padding=1),
|
| 77 |
+
self.activation,
|
| 78 |
+
)
|
| 79 |
+
head_num = 3
|
| 80 |
+
dim = 192
|
| 81 |
+
|
| 82 |
+
self.de_layer2_2 = nn.Sequential(
|
| 83 |
+
nn.Conv2d(192+128, 192, kernel_size=1, padding=0),
|
| 84 |
+
self.activation,
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
self.de_block_1 = Intra_SA(dim, head_num)
|
| 88 |
+
self.de_block_2 = Inter_SA(dim, head_num)
|
| 89 |
+
self.de_block_3 = Intra_SA(dim, head_num)
|
| 90 |
+
self.de_block_4 = Inter_SA(dim, head_num)
|
| 91 |
+
self.de_block_5 = Intra_SA(dim, head_num)
|
| 92 |
+
self.de_block_6 = Inter_SA(dim, head_num)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
self.de_layer2_1 = nn.Sequential(
|
| 96 |
+
nn.ConvTranspose2d(192, 64, kernel_size=4, stride=2, padding=1),
|
| 97 |
+
self.activation,
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
self.de_layer1_3 = nn.Sequential(
|
| 101 |
+
nn.Conv2d(128, 64, kernel_size=1, padding=0),
|
| 102 |
+
self.activation,
|
| 103 |
+
nn.Conv2d(64, 64, kernel_size=3, padding=1))
|
| 104 |
+
self.de_layer1_2 = nn.Sequential(
|
| 105 |
+
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
| 106 |
+
self.activation,
|
| 107 |
+
nn.Conv2d(64, 64, kernel_size=3, padding=1))
|
| 108 |
+
self.de_layer1_1 = nn.Sequential(
|
| 109 |
+
nn.Conv2d(64, 3, kernel_size=3, padding=1),
|
| 110 |
+
self.activation
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
def forward(self, x, residual_1, residual_2):
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
hx = self.de_layer3_1(x)
|
| 117 |
+
|
| 118 |
+
hx = self.de_layer2_2(torch.cat((hx, residual_2), dim = 1))
|
| 119 |
+
hx = self.de_block_1(hx)
|
| 120 |
+
hx = self.de_block_2(hx)
|
| 121 |
+
hx = self.de_block_3(hx)
|
| 122 |
+
hx = self.de_block_4(hx)
|
| 123 |
+
hx = self.de_block_5(hx)
|
| 124 |
+
hx = self.de_block_6(hx)
|
| 125 |
+
hx = self.de_layer2_1(hx)
|
| 126 |
+
|
| 127 |
+
hx = self.activation(self.de_layer1_3(torch.cat((hx, residual_1), dim = 1)) + hx)
|
| 128 |
+
hx = self.activation(self.de_layer1_2(hx) + hx)
|
| 129 |
+
hx = self.de_layer1_1(hx)
|
| 130 |
+
|
| 131 |
+
return hx
|
| 132 |
+
|
| 133 |
+
class Attention(nn.Module):
|
| 134 |
+
def __init__(self, head_num):
|
| 135 |
+
super(Attention, self).__init__()
|
| 136 |
+
self.num_attention_heads = head_num
|
| 137 |
+
self.softmax = nn.Softmax(dim=-1)
|
| 138 |
+
|
| 139 |
+
def transpose_for_scores(self, x):
|
| 140 |
+
B, N, C = x.size()
|
| 141 |
+
attention_head_size = int(C / self.num_attention_heads)
|
| 142 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, attention_head_size)
|
| 143 |
+
x = x.view(*new_x_shape)
|
| 144 |
+
return x.permute(0, 2, 1, 3).contiguous()
|
| 145 |
+
|
| 146 |
+
def forward(self, query_layer, key_layer, value_layer):
|
| 147 |
+
B, N, C = query_layer.size()
|
| 148 |
+
query_layer = self.transpose_for_scores(query_layer)
|
| 149 |
+
key_layer = self.transpose_for_scores(key_layer)
|
| 150 |
+
value_layer = self.transpose_for_scores(value_layer)
|
| 151 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
| 152 |
+
_, _, _, d = query_layer.size()
|
| 153 |
+
attention_scores = attention_scores / math.sqrt(d)
|
| 154 |
+
attention_probs = self.softmax(attention_scores)
|
| 155 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
| 156 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
| 157 |
+
new_context_layer_shape = context_layer.size()[:-2] + (C,)
|
| 158 |
+
attention_out = context_layer.view(*new_context_layer_shape)
|
| 159 |
+
|
| 160 |
+
return attention_out
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class Mlp(nn.Module):
|
| 164 |
+
def __init__(self, hidden_size):
|
| 165 |
+
super(Mlp, self).__init__()
|
| 166 |
+
self.fc1 = nn.Linear(hidden_size, 4*hidden_size)
|
| 167 |
+
self.fc2 = nn.Linear(4*hidden_size, hidden_size)
|
| 168 |
+
self.act_fn = torch.nn.functional.gelu
|
| 169 |
+
self._init_weights()
|
| 170 |
+
|
| 171 |
+
def _init_weights(self):
|
| 172 |
+
nn.init.xavier_uniform_(self.fc1.weight)
|
| 173 |
+
nn.init.xavier_uniform_(self.fc2.weight)
|
| 174 |
+
nn.init.normal_(self.fc1.bias, std=1e-6)
|
| 175 |
+
nn.init.normal_(self.fc2.bias, std=1e-6)
|
| 176 |
+
|
| 177 |
+
def forward(self, x):
|
| 178 |
+
x = self.fc1(x)
|
| 179 |
+
x = self.act_fn(x)
|
| 180 |
+
x = self.fc2(x)
|
| 181 |
+
return x
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
# CPE (Conditional Positional Embedding)
|
| 185 |
+
class PEG(nn.Module):
|
| 186 |
+
def __init__(self, hidden_size):
|
| 187 |
+
super(PEG, self).__init__()
|
| 188 |
+
self.PEG = nn.Conv2d(hidden_size, hidden_size, kernel_size=3, padding=1, groups=hidden_size)
|
| 189 |
+
|
| 190 |
+
def forward(self, x):
|
| 191 |
+
x = self.PEG(x) + x
|
| 192 |
+
return x
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class Intra_SA(nn.Module):
|
| 196 |
+
def __init__(self, dim, head_num):
|
| 197 |
+
super(Intra_SA, self).__init__()
|
| 198 |
+
self.hidden_size = dim // 2
|
| 199 |
+
self.head_num = head_num
|
| 200 |
+
self.attention_norm = nn.LayerNorm(dim)
|
| 201 |
+
self.conv_input = nn.Conv2d(dim, dim, kernel_size=1, padding=0)
|
| 202 |
+
self.qkv_local_h = nn.Linear(self.hidden_size, self.hidden_size * 3) # qkv_h
|
| 203 |
+
self.qkv_local_v = nn.Linear(self.hidden_size, self.hidden_size * 3) # qkv_v
|
| 204 |
+
self.fuse_out = nn.Conv2d(dim, dim, kernel_size=1, padding=0)
|
| 205 |
+
self.ffn_norm = nn.LayerNorm(dim)
|
| 206 |
+
self.ffn = Mlp(dim)
|
| 207 |
+
self.attn = Attention(head_num=self.head_num)
|
| 208 |
+
self.PEG = PEG(dim)
|
| 209 |
+
def forward(self, x):
|
| 210 |
+
h = x
|
| 211 |
+
B, C, H, W = x.size()
|
| 212 |
+
|
| 213 |
+
x = x.view(B, C, H*W).permute(0, 2, 1).contiguous()
|
| 214 |
+
x = self.attention_norm(x).permute(0, 2, 1).contiguous()
|
| 215 |
+
x = x.view(B, C, H, W)
|
| 216 |
+
|
| 217 |
+
x_input = torch.chunk(self.conv_input(x), 2, dim=1)
|
| 218 |
+
feature_h = (x_input[0]).permute(0, 2, 3, 1).contiguous()
|
| 219 |
+
feature_h = feature_h.view(B * H, W, C//2)
|
| 220 |
+
feature_v = (x_input[1]).permute(0, 3, 2, 1).contiguous()
|
| 221 |
+
feature_v = feature_v.view(B * W, H, C//2)
|
| 222 |
+
qkv_h = torch.chunk(self.qkv_local_h(feature_h), 3, dim=2)
|
| 223 |
+
qkv_v = torch.chunk(self.qkv_local_v(feature_v), 3, dim=2)
|
| 224 |
+
q_h, k_h, v_h = qkv_h[0], qkv_h[1], qkv_h[2]
|
| 225 |
+
q_v, k_v, v_v = qkv_v[0], qkv_v[1], qkv_v[2]
|
| 226 |
+
|
| 227 |
+
if H == W:
|
| 228 |
+
query = torch.cat((q_h, q_v), dim=0)
|
| 229 |
+
key = torch.cat((k_h, k_v), dim=0)
|
| 230 |
+
value = torch.cat((v_h, v_v), dim=0)
|
| 231 |
+
attention_output = self.attn(query, key, value)
|
| 232 |
+
attention_output = torch.chunk(attention_output, 2, dim=0)
|
| 233 |
+
attention_output_h = attention_output[0]
|
| 234 |
+
attention_output_v = attention_output[1]
|
| 235 |
+
attention_output_h = attention_output_h.view(B, H, W, C//2).permute(0, 3, 1, 2).contiguous()
|
| 236 |
+
attention_output_v = attention_output_v.view(B, W, H, C//2).permute(0, 3, 2, 1).contiguous()
|
| 237 |
+
attn_out = self.fuse_out(torch.cat((attention_output_h, attention_output_v), dim=1))
|
| 238 |
+
else:
|
| 239 |
+
attention_output_h = self.attn(q_h, k_h, v_h)
|
| 240 |
+
attention_output_v = self.attn(q_v, k_v, v_v)
|
| 241 |
+
attention_output_h = attention_output_h.view(B, H, W, C//2).permute(0, 3, 1, 2).contiguous()
|
| 242 |
+
attention_output_v = attention_output_v.view(B, W, H, C//2).permute(0, 3, 2, 1).contiguous()
|
| 243 |
+
attn_out = self.fuse_out(torch.cat((attention_output_h, attention_output_v), dim=1))
|
| 244 |
+
|
| 245 |
+
x = attn_out + h
|
| 246 |
+
x = x.view(B, C, H*W).permute(0, 2, 1).contiguous()
|
| 247 |
+
h = x
|
| 248 |
+
x = self.ffn_norm(x)
|
| 249 |
+
x = self.ffn(x)
|
| 250 |
+
x = x + h
|
| 251 |
+
x = x.permute(0, 2, 1).contiguous()
|
| 252 |
+
x = x.view(B, C, H, W)
|
| 253 |
+
|
| 254 |
+
x = self.PEG(x)
|
| 255 |
+
|
| 256 |
+
return x
|
| 257 |
+
|
| 258 |
+
class Inter_SA(nn.Module):
|
| 259 |
+
def __init__(self,dim, head_num):
|
| 260 |
+
super(Inter_SA, self).__init__()
|
| 261 |
+
self.hidden_size = dim
|
| 262 |
+
self.head_num = head_num
|
| 263 |
+
self.attention_norm = nn.LayerNorm(self.hidden_size)
|
| 264 |
+
self.conv_input = nn.Conv2d(self.hidden_size, self.hidden_size, kernel_size=1, padding=0)
|
| 265 |
+
self.conv_h = nn.Conv2d(self.hidden_size//2, 3 * (self.hidden_size//2), kernel_size=1, padding=0) # qkv_h
|
| 266 |
+
self.conv_v = nn.Conv2d(self.hidden_size//2, 3 * (self.hidden_size//2), kernel_size=1, padding=0) # qkv_v
|
| 267 |
+
self.ffn_norm = nn.LayerNorm(self.hidden_size)
|
| 268 |
+
self.ffn = Mlp(self.hidden_size)
|
| 269 |
+
self.fuse_out = nn.Conv2d(self.hidden_size, self.hidden_size, kernel_size=1, padding=0)
|
| 270 |
+
self.attn = Attention(head_num=self.head_num)
|
| 271 |
+
self.PEG = PEG(dim)
|
| 272 |
+
|
| 273 |
+
def forward(self, x):
|
| 274 |
+
h = x
|
| 275 |
+
B, C, H, W = x.size()
|
| 276 |
+
|
| 277 |
+
x = x.view(B, C, H*W).permute(0, 2, 1).contiguous()
|
| 278 |
+
x = self.attention_norm(x).permute(0, 2, 1).contiguous()
|
| 279 |
+
x = x.view(B, C, H, W)
|
| 280 |
+
#print(x.shape)
|
| 281 |
+
|
| 282 |
+
x_input = torch.chunk(self.conv_input(x), 2, dim=1)
|
| 283 |
+
feature_h = torch.chunk(self.conv_h(x_input[0]), 3, dim=1)
|
| 284 |
+
feature_v = torch.chunk(self.conv_v(x_input[1]), 3, dim=1)
|
| 285 |
+
query_h, key_h, value_h = feature_h[0], feature_h[1], feature_h[2]
|
| 286 |
+
query_v, key_v, value_v = feature_v[0], feature_v[1], feature_v[2]
|
| 287 |
+
|
| 288 |
+
horizontal_groups = torch.cat((query_h, key_h, value_h), dim=0)
|
| 289 |
+
horizontal_groups = horizontal_groups.permute(0, 2, 1, 3).contiguous()
|
| 290 |
+
horizontal_groups = horizontal_groups.view(3*B, H, -1)
|
| 291 |
+
horizontal_groups = torch.chunk(horizontal_groups, 3, dim=0)
|
| 292 |
+
query_h, key_h, value_h = horizontal_groups[0], horizontal_groups[1], horizontal_groups[2]
|
| 293 |
+
|
| 294 |
+
vertical_groups = torch.cat((query_v, key_v, value_v), dim=0)
|
| 295 |
+
vertical_groups = vertical_groups.permute(0, 3, 1, 2).contiguous()
|
| 296 |
+
vertical_groups = vertical_groups.view(3*B, W, -1)
|
| 297 |
+
vertical_groups = torch.chunk(vertical_groups, 3, dim=0)
|
| 298 |
+
query_v, key_v, value_v = vertical_groups[0], vertical_groups[1], vertical_groups[2]
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
if H == W:
|
| 302 |
+
query = torch.cat((query_h, query_v), dim=0)
|
| 303 |
+
key = torch.cat((key_h, key_v), dim=0)
|
| 304 |
+
value = torch.cat((value_h, value_v), dim=0)
|
| 305 |
+
attention_output = self.attn(query, key, value)
|
| 306 |
+
attention_output = torch.chunk(attention_output, 2, dim=0)
|
| 307 |
+
attention_output_h = attention_output[0]
|
| 308 |
+
attention_output_v = attention_output[1]
|
| 309 |
+
attention_output_h = attention_output_h.view(B, H, C//2, W).permute(0, 2, 1, 3).contiguous()
|
| 310 |
+
attention_output_v = attention_output_v.view(B, W, C//2, H).permute(0, 2, 3, 1).contiguous()
|
| 311 |
+
attn_out = self.fuse_out(torch.cat((attention_output_h, attention_output_v), dim=1))
|
| 312 |
+
else:
|
| 313 |
+
attention_output_h = self.attn(query_h, key_h, value_h)
|
| 314 |
+
attention_output_v = self.attn(query_v, key_v, value_v)
|
| 315 |
+
attention_output_h = attention_output_h.view(B, H, C//2, W).permute(0, 2, 1, 3).contiguous()
|
| 316 |
+
attention_output_v = attention_output_v.view(B, W, C//2, H).permute(0, 2, 3, 1).contiguous()
|
| 317 |
+
attn_out = self.fuse_out(torch.cat((attention_output_h, attention_output_v), dim=1))
|
| 318 |
+
|
| 319 |
+
x = attn_out + h
|
| 320 |
+
x = x.view(B, C, H*W).permute(0, 2, 1).contiguous()
|
| 321 |
+
h = x
|
| 322 |
+
x = self.ffn_norm(x)
|
| 323 |
+
x = self.ffn(x)
|
| 324 |
+
x = x + h
|
| 325 |
+
x = x.permute(0, 2, 1).contiguous()
|
| 326 |
+
x = x.view(B, C, H, W)
|
| 327 |
+
|
| 328 |
+
x = self.PEG(x)
|
| 329 |
+
|
| 330 |
+
return x
|
| 331 |
+
|
| 332 |
+
##########################################################################
|
| 333 |
+
class Strip_VSSB(nn.Module):
|
| 334 |
+
def __init__(self, dim, head_num):
|
| 335 |
+
super(Strip_VSSB, self).__init__()
|
| 336 |
+
|
| 337 |
+
self.intra = Intra_SA(dim, head_num)
|
| 338 |
+
self.inter = Inter_SA(dim, head_num)
|
| 339 |
+
|
| 340 |
+
def forward(self, x):
|
| 341 |
+
x = self.intra(x)
|
| 342 |
+
x = self.inter(x)
|
| 343 |
+
|
| 344 |
+
return x
|
| 345 |
+
|
| 346 |
+
class Stripformer(nn.Module):
|
| 347 |
+
def __init__(self):
|
| 348 |
+
super(Stripformer, self).__init__()
|
| 349 |
+
|
| 350 |
+
self.encoder = Embeddings()
|
| 351 |
+
head_num = 5
|
| 352 |
+
dim = 320
|
| 353 |
+
self.Trans_block_1 = Intra_SA(dim, head_num)
|
| 354 |
+
self.Trans_block_2 = Inter_SA(dim, head_num)
|
| 355 |
+
self.Trans_block_3 = Intra_SA(dim, head_num)
|
| 356 |
+
self.Trans_block_4 = Inter_SA(dim, head_num)
|
| 357 |
+
self.Trans_block_5 = Intra_SA(dim, head_num)
|
| 358 |
+
self.Trans_block_6 = Inter_SA(dim, head_num)
|
| 359 |
+
self.Trans_block_7 = Intra_SA(dim, head_num)
|
| 360 |
+
self.Trans_block_8 = Inter_SA(dim, head_num)
|
| 361 |
+
self.Trans_block_9 = Intra_SA(dim, head_num)
|
| 362 |
+
self.Trans_block_10 = Inter_SA(dim, head_num)
|
| 363 |
+
self.Trans_block_11 = Intra_SA(dim, head_num)
|
| 364 |
+
self.Trans_block_12 = Inter_SA(dim, head_num)
|
| 365 |
+
self.decoder = Embeddings_output()
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def forward(self, x):
|
| 369 |
+
|
| 370 |
+
hx, residual_1, residual_2 = self.encoder(x)
|
| 371 |
+
hx = self.Trans_block_1(hx)
|
| 372 |
+
hx = self.Trans_block_2(hx)
|
| 373 |
+
hx = self.Trans_block_3(hx)
|
| 374 |
+
hx = self.Trans_block_4(hx)
|
| 375 |
+
hx = self.Trans_block_5(hx)
|
| 376 |
+
hx = self.Trans_block_6(hx)
|
| 377 |
+
hx = self.Trans_block_7(hx)
|
| 378 |
+
hx = self.Trans_block_8(hx)
|
| 379 |
+
hx = self.Trans_block_9(hx)
|
| 380 |
+
hx = self.Trans_block_10(hx)
|
| 381 |
+
hx = self.Trans_block_11(hx)
|
| 382 |
+
hx = self.Trans_block_12(hx)
|
| 383 |
+
hx = self.decoder(hx, residual_1, residual_2)
|
| 384 |
+
|
| 385 |
+
return hx + x

#"""
import time
start_time = time.time()
inp = torch.randn(1, 3, 64, 64).cuda()
model = Stripformer().cuda()
out = model(inp)
print(out.shape)
print("--- %s seconds ---" % (time.time() - start_time))
pytorch_total_params = sum(p.numel() for p in model.parameters())
print("--- {num} parameters ---".format(num = pytorch_total_params))
pytorch_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("--- {num} trainable parameters ---".format(num = pytorch_trainable_params))
gpu_memmem_usage_bytes = torch.cuda.max_memory_allocated()
print(gpu_memmem_usage_bytes / 1024 / 1024 / 1024)  # 64: 0.37 128: 0.84 -> 256: 3.02 -> 512: 12.55
#"""
"""
import torch
from ptflops import get_model_complexity_info

with torch.cuda.device(0):
    net = model
    macs, params = get_model_complexity_info(net, (3, 512, 512), as_strings=True,
                                             print_per_layer_stat=True, verbose=True)
    print('{:<30} {:<8}'.format('Computational complexity: ', macs)) # 49.79 GMac
    print('{:<30} {:<8}'.format('Number of parameters: ', params)) # 6.06 M
"""
"""
import time
start_time = time.time()
inp = torch.randn(1, 32, 512, 512).cuda().to(dtype=torch.float32)
model = Strip_VSSB(dim=32, head_num = 4).cuda().to(dtype=torch.float32)
out = model(inp)
print(out.shape)
print("--- %s seconds ---" % (time.time() - start_time))
pytorch_total_params = sum(p.numel() for p in model.parameters())
print("--- {num} parameters ---".format(num = pytorch_total_params))
pytorch_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("--- {num} trainable parameters ---".format(num = pytorch_trainable_params))
gpu_memmem_usage_bytes = torch.cuda.max_memory_allocated()
print(gpu_memmem_usage_bytes / 1024 / 1024 / 1024)  # 128: 0.84 -> 256: 3.02 -> 512: 12.55
"""

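Note: the Embeddings encoder above downsamples twice with stride-2 convolutions and Embeddings_output upsamples twice with stride-2 transposed convolutions, so Stripformer expects the input height and width to be multiples of 4 for the skip connections to align. A minimal inference sketch follows for illustration only; the deblur_any_size helper is an assumption and not part of this repository.

import torch
import torch.nn.functional as F

# Illustrative helper (an assumption, not part of this repo): reflect-pad H and W up to a
# multiple of 4, run the model, then crop the output back to the original size.
def deblur_any_size(model, img):          # img: (1, 3, H, W), values in [0, 1]
    _, _, h, w = img.shape
    pad_h = (4 - h % 4) % 4
    pad_w = (4 - w % 4) % 4
    padded = F.pad(img, (0, pad_w, 0, pad_h), mode='reflect')
    with torch.no_grad():
        out = model(padded)
    return out[:, :, :h, :w]

# Example usage (assumes a CUDA device and pretrained weights loaded elsewhere):
# model = Stripformer().cuda().eval()
# blurry = torch.rand(1, 3, 711, 1063).cuda()
# sharp = deblur_any_size(model, blurry)
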
models/sota/XYScanNet.py
ADDED
|
@@ -0,0 +1,754 @@
| 1 |
+
import numbers
|
| 2 |
+
import math
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
|
| 10 |
+
from einops import rearrange, repeat
|
| 11 |
+
|
| 12 |
+
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
| 16 |
+
except ImportError:
|
| 17 |
+
causal_conv1d_fn, causal_conv1d_update = None, None
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
| 21 |
+
except ImportError:
|
| 22 |
+
selective_state_update = None
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
|
| 26 |
+
except ImportError:
|
| 27 |
+
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def to_3d(x):
|
| 31 |
+
return rearrange(x, 'b c h w -> b (h w) c')
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def to_4d(x, h, w):
|
| 35 |
+
return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class BiasFree_LayerNorm(nn.Module):
|
| 39 |
+
def __init__(self, normalized_shape):
|
| 40 |
+
super(BiasFree_LayerNorm, self).__init__()
|
| 41 |
+
if isinstance(normalized_shape, numbers.Integral):
|
| 42 |
+
normalized_shape = (normalized_shape,)
|
| 43 |
+
normalized_shape = torch.Size(normalized_shape)
|
| 44 |
+
|
| 45 |
+
assert len(normalized_shape) == 1
|
| 46 |
+
|
| 47 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
| 48 |
+
self.normalized_shape = normalized_shape
|
| 49 |
+
|
| 50 |
+
def forward(self, x):
|
| 51 |
+
sigma = x.var(-1, keepdim=True, unbiased=False)
|
| 52 |
+
return x / torch.sqrt(sigma + 1e-5) * self.weight
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class WithBias_LayerNorm(nn.Module):
|
| 56 |
+
def __init__(self, normalized_shape):
|
| 57 |
+
super(WithBias_LayerNorm, self).__init__()
|
| 58 |
+
if isinstance(normalized_shape, numbers.Integral):
|
| 59 |
+
normalized_shape = (normalized_shape,)
|
| 60 |
+
normalized_shape = torch.Size(normalized_shape)
|
| 61 |
+
|
| 62 |
+
assert len(normalized_shape) == 1
|
| 63 |
+
|
| 64 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
| 65 |
+
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
| 66 |
+
self.normalized_shape = normalized_shape
|
| 67 |
+
|
| 68 |
+
def forward(self, x):
|
| 69 |
+
mu = x.mean(-1, keepdim=True)
|
| 70 |
+
sigma = x.var(-1, keepdim=True, unbiased=False)
|
| 71 |
+
return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class LayerNorm(nn.Module):
|
| 75 |
+
def __init__(self, dim, LayerNorm_type):
|
| 76 |
+
super(LayerNorm, self).__init__()
|
| 77 |
+
if LayerNorm_type == 'BiasFree':
|
| 78 |
+
self.body = BiasFree_LayerNorm(dim)
|
| 79 |
+
else:
|
| 80 |
+
self.body = WithBias_LayerNorm(dim)
|
| 81 |
+
|
| 82 |
+
def forward(self, x):
|
| 83 |
+
h, w = x.shape[-2:]
|
| 84 |
+
return to_4d(self.body(to_3d(x)), h, w)
|
| 85 |
+
|
| 86 |
+
##########################################################################
|
| 87 |
+
def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1):
|
| 88 |
+
return nn.Conv2d(
|
| 89 |
+
in_channels, out_channels, kernel_size,
|
| 90 |
+
padding=(kernel_size//2), bias=bias, stride = stride)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
"""
|
| 94 |
+
Borrow from "https://github.com/state-spaces/mamba.git"
|
| 95 |
+
@article{mamba,
|
| 96 |
+
title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
|
| 97 |
+
author={Gu, Albert and Dao, Tri},
|
| 98 |
+
journal={arXiv preprint arXiv:2312.00752},
|
| 99 |
+
year={2023}
|
| 100 |
+
}
|
| 101 |
+
"""
|
| 102 |
+
class Mamba(nn.Module):
|
| 103 |
+
def __init__(
|
| 104 |
+
self,
|
| 105 |
+
d_model,
|
| 106 |
+
d_state=16,
|
| 107 |
+
d_conv=4,
|
| 108 |
+
expand=2,
|
| 109 |
+
dt_rank="auto",
|
| 110 |
+
dt_min=0.001,
|
| 111 |
+
dt_max=0.1,
|
| 112 |
+
dt_init="random",
|
| 113 |
+
dt_scale=1.0,
|
| 114 |
+
dt_init_floor=1e-4,
|
| 115 |
+
conv_bias=True,
|
| 116 |
+
bias=False,
|
| 117 |
+
use_fast_path=True, # Fused kernel options
|
| 118 |
+
layer_idx=None,
|
| 119 |
+
device=None,
|
| 120 |
+
dtype=None,
|
| 121 |
+
):
|
| 122 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 123 |
+
super().__init__()
|
| 124 |
+
self.d_model = d_model
|
| 125 |
+
self.d_state = d_state
|
| 126 |
+
self.d_conv = d_conv
|
| 127 |
+
self.expand = expand
|
| 128 |
+
self.d_inner = int(self.expand * self.d_model)
|
| 129 |
+
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
|
| 130 |
+
self.use_fast_path = use_fast_path
|
| 131 |
+
self.layer_idx = layer_idx
|
| 132 |
+
|
| 133 |
+
self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
|
| 134 |
+
|
| 135 |
+
self.conv1d = nn.Conv1d(
|
| 136 |
+
in_channels=self.d_inner,
|
| 137 |
+
out_channels=self.d_inner,
|
| 138 |
+
bias=conv_bias,
|
| 139 |
+
kernel_size=d_conv,
|
| 140 |
+
groups=self.d_inner,
|
| 141 |
+
padding=d_conv - 1,
|
| 142 |
+
**factory_kwargs,
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
self.activation = "silu"
|
| 146 |
+
self.act = nn.SiLU()
|
| 147 |
+
|
| 148 |
+
self.x_proj = nn.Linear(
|
| 149 |
+
self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
|
| 150 |
+
)
|
| 151 |
+
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
|
| 152 |
+
|
| 153 |
+
# Initialize special dt projection to preserve variance at initialization
|
| 154 |
+
dt_init_std = self.dt_rank**-0.5 * dt_scale
|
| 155 |
+
if dt_init == "constant":
|
| 156 |
+
nn.init.constant_(self.dt_proj.weight, dt_init_std)
|
| 157 |
+
elif dt_init == "random":
|
| 158 |
+
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
|
| 159 |
+
else:
|
| 160 |
+
raise NotImplementedError
|
| 161 |
+
|
| 162 |
+
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
|
| 163 |
+
dt = torch.exp(
|
| 164 |
+
torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
|
| 165 |
+
+ math.log(dt_min)
|
| 166 |
+
).clamp(min=dt_init_floor)
|
| 167 |
+
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
| 168 |
+
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
| 169 |
+
with torch.no_grad():
|
| 170 |
+
self.dt_proj.bias.copy_(inv_dt)
|
| 171 |
+
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
|
| 172 |
+
self.dt_proj.bias._no_reinit = True
|
| 173 |
+
|
| 174 |
+
# S4D real initialization
|
| 175 |
+
A = repeat(
|
| 176 |
+
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
|
| 177 |
+
"n -> d n",
|
| 178 |
+
d=self.d_inner,
|
| 179 |
+
).contiguous()
|
| 180 |
+
A_log = torch.log(A) # Keep A_log in fp32
|
| 181 |
+
self.A_log = nn.Parameter(A_log)
|
| 182 |
+
self.A_log._no_weight_decay = True
|
| 183 |
+
|
| 184 |
+
# D "skip" parameter
|
| 185 |
+
self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
|
| 186 |
+
self.D._no_weight_decay = True
|
| 187 |
+
|
| 188 |
+
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
|
| 189 |
+
|
| 190 |
+
def forward(self, hidden_states, inference_params=None):
|
| 191 |
+
"""
|
| 192 |
+
hidden_states: (B, L, D)
|
| 193 |
+
Returns: same shape as hidden_states
|
| 194 |
+
"""
|
| 195 |
+
batch, seqlen, dim = hidden_states.shape
|
| 196 |
+
|
| 197 |
+
conv_state, ssm_state = None, None
|
| 198 |
+
if inference_params is not None:
|
| 199 |
+
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
|
| 200 |
+
if inference_params.seqlen_offset > 0:
|
| 201 |
+
# The states are updated inplace
|
| 202 |
+
out, _, _ = self.step(hidden_states, conv_state, ssm_state)
|
| 203 |
+
return out
|
| 204 |
+
|
| 205 |
+
# We do matmul and transpose BLH -> HBL at the same time
|
| 206 |
+
xz = rearrange(
|
| 207 |
+
self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
|
| 208 |
+
"d (b l) -> b d l",
|
| 209 |
+
l=seqlen,
|
| 210 |
+
)
|
| 211 |
+
if self.in_proj.bias is not None:
|
| 212 |
+
xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
|
| 213 |
+
|
| 214 |
+
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
| 215 |
+
# In the backward pass we write dx and dz next to each other to avoid torch.cat
|
| 216 |
+
if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None: # Doesn't support outputting the states
|
| 217 |
+
out = mamba_inner_fn(
|
| 218 |
+
xz,
|
| 219 |
+
self.conv1d.weight,
|
| 220 |
+
self.conv1d.bias,
|
| 221 |
+
self.x_proj.weight,
|
| 222 |
+
self.dt_proj.weight,
|
| 223 |
+
self.out_proj.weight,
|
| 224 |
+
self.out_proj.bias,
|
| 225 |
+
A,
|
| 226 |
+
None, # input-dependent B
|
| 227 |
+
None, # input-dependent C
|
| 228 |
+
self.D.float(),
|
| 229 |
+
delta_bias=self.dt_proj.bias.float(),
|
| 230 |
+
delta_softplus=True,
|
| 231 |
+
)
|
| 232 |
+
else:
|
| 233 |
+
x, z = xz.chunk(2, dim=1)
|
| 234 |
+
# Compute short convolution
|
| 235 |
+
if conv_state is not None:
|
| 236 |
+
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
|
| 237 |
+
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
|
| 238 |
+
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
|
| 239 |
+
if causal_conv1d_fn is None:
|
| 240 |
+
x = self.act(self.conv1d(x)[..., :seqlen])
|
| 241 |
+
else:
|
| 242 |
+
assert self.activation in ["silu", "swish"]
|
| 243 |
+
x = causal_conv1d_fn(
|
| 244 |
+
x=x,
|
| 245 |
+
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
| 246 |
+
bias=self.conv1d.bias,
|
| 247 |
+
activation=self.activation,
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
# We're careful here about the layout, to avoid extra transposes.
|
| 251 |
+
# We want dt to have d as the slowest moving dimension
|
| 252 |
+
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
| 253 |
+
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
|
| 254 |
+
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
| 255 |
+
dt = self.dt_proj.weight @ dt.t()
|
| 256 |
+
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
|
| 257 |
+
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
| 258 |
+
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
| 259 |
+
assert self.activation in ["silu", "swish"]
|
| 260 |
+
y = selective_scan_fn(
|
| 261 |
+
x,
|
| 262 |
+
dt,
|
| 263 |
+
A,
|
| 264 |
+
B,
|
| 265 |
+
C,
|
| 266 |
+
self.D.float(),
|
| 267 |
+
z=z,
|
| 268 |
+
delta_bias=self.dt_proj.bias.float(),
|
| 269 |
+
delta_softplus=True,
|
| 270 |
+
return_last_state=ssm_state is not None,
|
| 271 |
+
)
|
| 272 |
+
if ssm_state is not None:
|
| 273 |
+
y, last_state = y
|
| 274 |
+
ssm_state.copy_(last_state)
|
| 275 |
+
y = rearrange(y, "b d l -> b l d")
|
| 276 |
+
out = self.out_proj(y)
|
| 277 |
+
return out
|
| 278 |
+
|
| 279 |
+
def step(self, hidden_states, conv_state, ssm_state):
|
| 280 |
+
dtype = hidden_states.dtype
|
| 281 |
+
assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
|
| 282 |
+
xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
|
| 283 |
+
x, z = xz.chunk(2, dim=-1) # (B D)
|
| 284 |
+
|
| 285 |
+
# Conv step
|
| 286 |
+
if causal_conv1d_update is None:
|
| 287 |
+
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
|
| 288 |
+
conv_state[:, :, -1] = x
|
| 289 |
+
x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
|
| 290 |
+
if self.conv1d.bias is not None:
|
| 291 |
+
x = x + self.conv1d.bias
|
| 292 |
+
x = self.act(x).to(dtype=dtype)
|
| 293 |
+
else:
|
| 294 |
+
x = causal_conv1d_update(
|
| 295 |
+
x,
|
| 296 |
+
conv_state,
|
| 297 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
| 298 |
+
self.conv1d.bias,
|
| 299 |
+
self.activation,
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
x_db = self.x_proj(x) # (B dt_rank+2*d_state)
|
| 303 |
+
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
| 304 |
+
# Don't add dt_bias here
|
| 305 |
+
dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
|
| 306 |
+
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
| 307 |
+
|
| 308 |
+
# SSM step
|
| 309 |
+
if selective_state_update is None:
|
| 310 |
+
# Discretize A and B
|
| 311 |
+
dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
|
| 312 |
+
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
|
| 313 |
+
dB = torch.einsum("bd,bn->bdn", dt, B)
|
| 314 |
+
ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
|
| 315 |
+
y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
|
| 316 |
+
y = y + self.D.to(dtype) * x
|
| 317 |
+
y = y * self.act(z) # (B D)
|
| 318 |
+
else:
|
| 319 |
+
y = selective_state_update(
|
| 320 |
+
ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
out = self.out_proj(y)
|
| 324 |
+
return out.unsqueeze(1), conv_state, ssm_state
|
| 325 |
+
|
| 326 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
| 327 |
+
device = self.out_proj.weight.device
|
| 328 |
+
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
|
| 329 |
+
conv_state = torch.zeros(
|
| 330 |
+
batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
|
| 331 |
+
)
|
| 332 |
+
ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
|
| 333 |
+
# ssm_dtype = torch.float32
|
| 334 |
+
ssm_state = torch.zeros(
|
| 335 |
+
batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
|
| 336 |
+
)
|
| 337 |
+
return conv_state, ssm_state
|
| 338 |
+
|
| 339 |
+
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
|
| 340 |
+
assert self.layer_idx is not None
|
| 341 |
+
if self.layer_idx not in inference_params.key_value_memory_dict:
|
| 342 |
+
batch_shape = (batch_size,)
|
| 343 |
+
conv_state = torch.zeros(
|
| 344 |
+
batch_size,
|
| 345 |
+
self.d_model * self.expand,
|
| 346 |
+
self.d_conv,
|
| 347 |
+
device=self.conv1d.weight.device,
|
| 348 |
+
dtype=self.conv1d.weight.dtype,
|
| 349 |
+
)
|
| 350 |
+
ssm_state = torch.zeros(
|
| 351 |
+
batch_size,
|
| 352 |
+
self.d_model * self.expand,
|
| 353 |
+
self.d_state,
|
| 354 |
+
device=self.dt_proj.weight.device,
|
| 355 |
+
dtype=self.dt_proj.weight.dtype,
|
| 356 |
+
# dtype=torch.float32,
|
| 357 |
+
)
|
| 358 |
+
inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
|
| 359 |
+
else:
|
| 360 |
+
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
|
| 361 |
+
# TODO: What if batch size changes between generation, and we reuse the same states?
|
| 362 |
+
if initialize_states:
|
| 363 |
+
conv_state.zero_()
|
| 364 |
+
ssm_state.zero_()
|
| 365 |
+
return conv_state, ssm_state
|
| 366 |
+
|
| 367 |
+
##########################################################################
|
| 368 |
+
## Feed-forward Network
|
| 369 |
+
class FFN(nn.Module):
|
| 370 |
+
def __init__(self, dim, ffn_expansion_factor, bias):
|
| 371 |
+
super(FFN, self).__init__()
|
| 372 |
+
|
| 373 |
+
hidden_features = int(dim*ffn_expansion_factor)
|
| 374 |
+
|
| 375 |
+
self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
|
| 376 |
+
|
| 377 |
+
self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias, dilation=1)
|
| 378 |
+
|
| 379 |
+
self.win_size = 8
|
| 380 |
+
|
| 381 |
+
self.modulator = nn.Parameter(torch.ones(self.win_size, self.win_size, dim*2)) # modulator
|
| 382 |
+
|
| 383 |
+
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
|
| 384 |
+
|
| 385 |
+
def forward(self, x):
|
| 386 |
+
b, c, h, w = x.shape
|
| 387 |
+
h1, w1 = h//self.win_size, w//self.win_size
|
| 388 |
+
x = self.project_in(x)
|
| 389 |
+
x = self.dwconv(x)
|
| 390 |
+
x_win = rearrange(x, 'b c (wsh h1) (wsw w1) -> b h1 w1 wsh wsw c', wsh=self.win_size, wsw=self.win_size)
|
| 391 |
+
x_win = x_win * self.modulator
|
| 392 |
+
x = rearrange(x_win, 'b h1 w1 wsh wsw c -> b c (wsh h1) (wsw w1)', wsh=self.win_size, wsw=self.win_size, h1=h1, w1=w1)
|
| 393 |
+
x1, x2 = x.chunk(2, dim=1)
|
| 394 |
+
x = x1 * x2
|
| 395 |
+
x = self.project_out(x)
|
| 396 |
+
return x
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
##########################################################################
|
| 400 |
+
## Gated Depth-wise Feed-forward Network (GDFN)
|
| 401 |
+
class GDFN(nn.Module):
|
| 402 |
+
def __init__(self, dim, ffn_expansion_factor, bias):
|
| 403 |
+
super(GDFN, self).__init__()
|
| 404 |
+
|
| 405 |
+
hidden_features = int(dim*ffn_expansion_factor)
|
| 406 |
+
|
| 407 |
+
self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
|
| 408 |
+
|
| 409 |
+
self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias, dilation=1)
|
| 410 |
+
|
| 411 |
+
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
|
| 412 |
+
|
| 413 |
+
def forward(self, x):
|
| 414 |
+
x = self.project_in(x)
|
| 415 |
+
x = self.dwconv(x)
|
| 416 |
+
x1, x2 = x.chunk(2, dim=1)
|
| 417 |
+
x = F.silu(x1) * x2
|
| 418 |
+
x = self.project_out(x)
|
| 419 |
+
return x
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
##########################################################################
|
| 423 |
+
## Overlapped image patch embedding with 3x3 Conv
|
| 424 |
+
class OverlapPatchEmbed(nn.Module):
|
| 425 |
+
def __init__(self, in_c=3, embed_dim=48, bias=False):
|
| 426 |
+
super(OverlapPatchEmbed, self).__init__()
|
| 427 |
+
|
| 428 |
+
self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)
|
| 429 |
+
|
| 430 |
+
def forward(self, x):
|
| 431 |
+
x = self.proj(x)
|
| 432 |
+
|
| 433 |
+
return x
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
##########################################################################
|
| 437 |
+
## Resizing modules
|
| 438 |
+
class Downsample(nn.Module):
|
| 439 |
+
def __init__(self, n_feat):
|
| 440 |
+
super(Downsample, self).__init__()
|
| 441 |
+
|
| 442 |
+
self.body = nn.Sequential(nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
|
| 443 |
+
nn.Conv2d(n_feat, n_feat * 2, 3, stride=1, padding=1, bias=False))
|
| 444 |
+
|
| 445 |
+
def forward(self, x):
|
| 446 |
+
return self.body(x)
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
class Upsample(nn.Module):
|
| 450 |
+
def __init__(self, n_feat):
|
| 451 |
+
super(Upsample, self).__init__()
|
| 452 |
+
|
| 453 |
+
self.body = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
|
| 454 |
+
nn.Conv2d(n_feat, n_feat // 2, 3, stride=1, padding=1, bias=False))
|
| 455 |
+
|
| 456 |
+
def forward(self, x):
|
| 457 |
+
return self.body(x)
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
"""
|
| 461 |
+
Borrow from "https://github.com/pp00704831/Stripformer-ECCV-2022-.git"
|
| 462 |
+
@inproceedings{Tsai2022Stripformer,
|
| 463 |
+
author = {Fu-Jen Tsai and Yan-Tsung Peng and Yen-Yu Lin and Chung-Chi Tsai and Chia-Wen Lin},
|
| 464 |
+
title = {Stripformer: Strip Transformer for Fast Image Deblurring},
|
| 465 |
+
booktitle = {ECCV},
|
| 466 |
+
year = {2022}
|
| 467 |
+
}
|
| 468 |
+
"""
|
| 469 |
+
class Intra_VSSM(nn.Module):
|
| 470 |
+
def __init__(self, dim, vssm_expansion_factor, bias): # gated = True
|
| 471 |
+
super(Intra_VSSM, self).__init__()
|
| 472 |
+
hidden = int(dim*vssm_expansion_factor)
|
| 473 |
+
|
| 474 |
+
self.proj_in = nn.Conv2d(dim, hidden*2, kernel_size=1, bias=bias)
|
| 475 |
+
self.dwconv = nn.Conv2d(hidden*2, hidden*2, kernel_size=3, stride=1, padding=1, groups=hidden*2, bias=bias)
|
| 476 |
+
self.proj_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)
|
| 477 |
+
|
| 478 |
+
self.conv_input = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
|
| 479 |
+
self.fuse_out = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
|
| 480 |
+
self.mamba = Mamba(d_model=hidden // 2)
|
| 481 |
+
|
| 482 |
+
def forward_core(self, x):
|
| 483 |
+
B, C, H, W = x.size()
|
| 484 |
+
|
| 485 |
+
x_input = torch.chunk(self.conv_input(x), 2, dim=1)
|
| 486 |
+
|
| 487 |
+
feature_h = (x_input[0]).permute(0, 2, 3, 1).contiguous()
|
| 488 |
+
feature_h = feature_h.view(B * H, W, C//2)
|
| 489 |
+
|
| 490 |
+
feature_v = (x_input[1]).permute(0, 3, 2, 1).contiguous()
|
| 491 |
+
feature_v = feature_v.view(B * W, H, C//2)
|
| 492 |
+
|
| 493 |
+
if H == W:
|
| 494 |
+
feature = torch.cat((feature_h, feature_v), dim=0) # B * H * 2, W, C//2
|
| 495 |
+
scan_output = self.mamba(feature)
|
| 496 |
+
scan_output = torch.chunk(scan_output, 2, dim=0)
|
| 497 |
+
scan_output_h = scan_output[0]
|
| 498 |
+
scan_output_v = scan_output[1]
|
| 499 |
+
else:
|
| 500 |
+
scan_output_h = self.mamba(feature_h)
|
| 501 |
+
scan_output_v = self.mamba(feature_v)
|
| 502 |
+
|
| 503 |
+
scan_output_h = scan_output_h.view(B, H, W, C//2).permute(0, 3, 1, 2).contiguous()
|
| 504 |
+
scan_output_v = scan_output_v.view(B, W, H, C//2).permute(0, 3, 2, 1).contiguous()
|
| 505 |
+
scan_output = self.fuse_out(torch.cat((scan_output_h, scan_output_v), dim=1))
|
| 506 |
+
|
| 507 |
+
return scan_output
|
| 508 |
+
|
| 509 |
+
def forward(self, x):
|
| 510 |
+
x = self.proj_in(x)
|
| 511 |
+
x, x_ = self.dwconv(x).chunk(2, dim=1)
|
| 512 |
+
x = self.forward_core(x)
|
| 513 |
+
x = F.silu(x_) * x
|
| 514 |
+
x = self.proj_out(x)
|
| 515 |
+
return x
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
class Inter_VSSM(nn.Module):
|
| 519 |
+
def __init__(self, dim, vssm_expansion_factor, bias): # gated = True
|
| 520 |
+
super(Inter_VSSM, self).__init__()
|
| 521 |
+
hidden = int(dim*vssm_expansion_factor)
|
| 522 |
+
|
| 523 |
+
self.proj_in = nn.Conv2d(dim, hidden*2, kernel_size=1, bias=bias)
|
| 524 |
+
self.dwconv = nn.Conv2d(hidden*2, hidden*2, kernel_size=3, stride=1, padding=1, groups=hidden*2, bias=bias)
|
| 525 |
+
self.proj_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)
|
| 526 |
+
|
| 527 |
+
self.avg_pool = nn.AdaptiveAvgPool2d((None,1))
|
| 528 |
+
self.conv_input = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
|
| 529 |
+
self.fuse_out = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
|
| 530 |
+
self.mamba = Mamba(d_model=hidden // 2)
|
| 531 |
+
self.sigmoid = nn.Sigmoid()
|
| 532 |
+
|
| 533 |
+
def forward_core(self, x):
|
| 534 |
+
B, C, H, W = x.size()
|
| 535 |
+
|
| 536 |
+
x_input = torch.chunk(self.conv_input(x), 2, dim=1) # B, C, H, W
|
| 537 |
+
|
| 538 |
+
feature_h = x_input[0].permute(0, 2, 1, 3).contiguous() # B, H, C//2, W
|
| 539 |
+
feature_h_score = self.avg_pool(feature_h) # B, H, C//2, 1
|
| 540 |
+
feature_h_score = feature_h_score.view(B, H, -1)
|
| 541 |
+
|
| 542 |
+
feature_v = x_input[1].permute(0, 3, 1, 2).contiguous() # B, W, C//2, H
|
| 543 |
+
feature_v_score = self.avg_pool(feature_v) # B, W, C//2, 1
|
| 544 |
+
feature_v_score = feature_v_score.view(B, W, -1)
|
| 545 |
+
|
| 546 |
+
if H == W:
|
| 547 |
+
feature_score = torch.cat((feature_h_score, feature_v_score), dim=0) # B * 2, W or H, C//2
|
| 548 |
+
scan_score = self.mamba(feature_score)
|
| 549 |
+
scan_score = torch.chunk(scan_score, 2, dim=0)
|
| 550 |
+
scan_score_h = scan_score[0]
|
| 551 |
+
scan_score_v = scan_score[1]
|
| 552 |
+
else:
|
| 553 |
+
scan_score_h = self.mamba(feature_h_score)
|
| 554 |
+
scan_score_v = self.mamba(feature_v_score)
|
| 555 |
+
|
| 556 |
+
scan_score_h = self.sigmoid(scan_score_h)
|
| 557 |
+
scan_score_v = self.sigmoid(scan_score_v)
|
| 558 |
+
feature_h = feature_h*scan_score_h[:,:,:,None]
|
| 559 |
+
feature_v = feature_v*scan_score_v[:,:,:,None]
|
| 560 |
+
feature_h = feature_h.view(B, H, C//2, W).permute(0, 2, 1, 3).contiguous()
|
| 561 |
+
feature_v = feature_v.view(B, W, C//2, H).permute(0, 2, 3, 1).contiguous()
|
| 562 |
+
output = self.fuse_out(torch.cat((feature_h, feature_v), dim=1))
|
| 563 |
+
|
| 564 |
+
return output
|
| 565 |
+
|
| 566 |
+
def forward(self, x):
|
| 567 |
+
x = self.proj_in(x)
|
| 568 |
+
x, x_ = self.dwconv(x).chunk(2, dim=1)
|
| 569 |
+
x = self.forward_core(x)
|
| 570 |
+
x = F.silu(x_) * x
|
| 571 |
+
x = self.proj_out(x)
|
| 572 |
+
return x
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
##########################################################################
|
| 576 |
+
class Strip_VSSB(nn.Module):
|
| 577 |
+
def __init__(self, dim, vssm_expansion_factor, ffn_expansion_factor, bias=False, ssm=False, LayerNorm_type='WithBias'):
|
| 578 |
+
super(Strip_VSSB, self).__init__()
|
| 579 |
+
self.ssm = ssm
|
| 580 |
+
if self.ssm == True:
|
| 581 |
+
self.norm1_ssm = LayerNorm(dim, LayerNorm_type)
|
| 582 |
+
self.norm2_ssm = LayerNorm(dim, LayerNorm_type)
|
| 583 |
+
self.intra = Intra_VSSM(dim, vssm_expansion_factor, bias)
|
| 584 |
+
self.inter = Inter_VSSM(dim, vssm_expansion_factor, bias)
|
| 585 |
+
self.norm1_ffn = LayerNorm(dim, LayerNorm_type)
|
| 586 |
+
self.norm2_ffn = LayerNorm(dim, LayerNorm_type)
|
| 587 |
+
self.ffn1 = GDFN(dim, ffn_expansion_factor, bias)
|
| 588 |
+
self.ffn2 = GDFN(dim, ffn_expansion_factor, bias)
|
| 589 |
+
|
| 590 |
+
def forward(self, x):
|
| 591 |
+
if self.ssm == True:
|
| 592 |
+
x = x + self.intra(self.norm1_ssm(x))
|
| 593 |
+
x = x + self.ffn1(self.norm1_ffn(x))
|
| 594 |
+
if self.ssm == True:
|
| 595 |
+
x = x + self.inter(self.norm2_ssm(x))
|
| 596 |
+
x = x + self.ffn2(self.norm2_ffn(x))
|
| 597 |
+
|
| 598 |
+
return x
|
| 599 |
+
|
| 600 |
+
|
| 601 |
+
##########################################################################
|
| 602 |
+
##---------- Cross-level Feature Fusion by Adding Sigmoid(KL-Div) * Multi-Scale Feat -----------------------
|
| 603 |
+
class CLFF(nn.Module):
|
| 604 |
+
def __init__(self, dim, dim_n1, dim_n2, bias=False):
|
| 605 |
+
super(CLFF, self).__init__()
|
| 606 |
+
|
| 607 |
+
self.conv = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
| 608 |
+
self.conv_n1 = nn.Conv2d(dim_n1, dim, kernel_size=1, bias=bias)
|
| 609 |
+
self.conv_n2 = nn.Conv2d(dim_n2, dim, kernel_size=1, bias=bias)
|
| 610 |
+
self.fuse_out1 = nn.Conv2d(dim*2, dim, kernel_size=1, bias=bias)
|
| 611 |
+
|
| 612 |
+
self.log_sigmoid = nn.LogSigmoid()
|
| 613 |
+
self.sigmoid = nn.Sigmoid()
|
| 614 |
+
|
| 615 |
+
def forward(self, x, n1, n2):
|
| 616 |
+
x_ = self.conv(x)
|
| 617 |
+
n1_ = self.conv_n1(n1)
|
| 618 |
+
n2_ = self.conv_n2(n2)
|
| 619 |
+
kl_n1 = F.kl_div(input=self.log_sigmoid(n1_), target=self.log_sigmoid(x_), log_target=True)
|
| 620 |
+
kl_n2 = F.kl_div(input=self.log_sigmoid(n2_), target=self.log_sigmoid(x_), log_target=True)
|
| 621 |
+
#g = self.sigmoid(x_)
|
| 622 |
+
g1 = self.sigmoid(kl_n1)
|
| 623 |
+
g2 = self.sigmoid(kl_n2)
|
| 624 |
+
#x = (1 + g) * x_ + (1 - g) * (g1 * n1_ + g2 * n2_)
|
| 625 |
+
x = self.fuse_out1(torch.cat((x_, g1 * n1_ + g2 * n2_), dim=1))
|
| 626 |
+
|
| 627 |
+
return x
|
| 628 |
+
|
| 629 |
+
##########################################################################
|
| 630 |
+
##---------- XYScanNet -----------------------
|
| 631 |
+
class XYScanNet(nn.Module):
|
| 632 |
+
def __init__(self,
|
| 633 |
+
inp_channels=3,
|
| 634 |
+
out_channels=3,
|
| 635 |
+
dim = 24, # 48, 72, 96, 120, 144, default: 72
|
| 636 |
+
num_blocks = [3,3,6],
|
| 637 |
+
vssm_expansion_factor = 1, # 1 or 2
|
| 638 |
+
ffn_expansion_factor = 1, # 1 or 3
|
| 639 |
+
bias = False,
|
| 640 |
+
LayerNorm_type = 'WithBias', ## Other option 'BiasFree'
|
| 641 |
+
):
|
| 642 |
+
|
| 643 |
+
super(XYScanNet, self).__init__()
|
| 644 |
+
|
| 645 |
+
self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
|
| 646 |
+
|
| 647 |
+
self.encoder_level1 = nn.Sequential(*[Strip_VSSB(dim=dim, vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
|
| 648 |
+
bias=bias, ssm=False, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
|
| 649 |
+
|
| 650 |
+
self.down1_2 = Downsample(dim) ## From Level 1 to Level 2
|
| 651 |
+
self.encoder_level2 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**1), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
|
| 652 |
+
bias=bias, ssm=False, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
|
| 653 |
+
|
| 654 |
+
self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3
|
| 655 |
+
self.encoder_level3 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**2), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
|
| 656 |
+
bias=bias, ssm=False, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
|
| 657 |
+
|
| 658 |
+
self.decoder_level3 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**2), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
|
| 659 |
+
bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
|
| 660 |
+
|
| 661 |
+
self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2
|
| 662 |
+
self.clff_level2 = CLFF(int(dim*2**1), dim_n1=int(dim*2**0), dim_n2=(dim*2**2), bias=bias)
|
| 663 |
+
self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
|
| 664 |
+
self.decoder_level2 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**1), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
|
| 665 |
+
bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
|
| 666 |
+
|
| 667 |
+
self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1
|
| 668 |
+
self.clff_level1 = CLFF(int(dim*2**0), dim_n1=int(dim*2**1), dim_n2=(dim*2**2), bias=bias)
|
| 669 |
+
self.reduce_chan_level1 = nn.Conv2d(int(dim*2**1), int(dim*2**0), kernel_size=1, bias=bias)
|
| 670 |
+
self.decoder_level1 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**0), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
|
| 671 |
+
bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
|
| 672 |
+
|
| 673 |
+
# self.refinement = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**0), expansion_factor=expansion_factor, bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])
|
| 674 |
+
|
| 675 |
+
self.output = nn.Conv2d(int(dim*2**0), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
|
| 676 |
+
|
| 677 |
+
def forward(self, inp_img):
|
| 678 |
+
|
| 679 |
+
# Encoder
|
| 680 |
+
inp_enc_level1 = self.patch_embed(inp_img)
|
| 681 |
+
out_enc_level1 = self.encoder_level1(inp_enc_level1)
|
| 682 |
+
out_enc_level1_2 = F.interpolate(out_enc_level1, scale_factor=0.5) # dim*2, lvl1 down-scaled to lvl2
|
| 683 |
+
|
| 684 |
+
inp_enc_level2 = self.down1_2(out_enc_level1)
|
| 685 |
+
out_enc_level2 = self.encoder_level2(inp_enc_level2)
|
| 686 |
+
out_enc_level2_1 = F.interpolate(out_enc_level2, scale_factor=2) # dim*2, lvl2 up-scaled to lvl1
|
| 687 |
+
|
| 688 |
+
inp_enc_level3 = self.down2_3(out_enc_level2)
|
| 689 |
+
out_enc_level3 = self.encoder_level3(inp_enc_level3)
|
| 690 |
+
out_enc_level3_2 = F.interpolate(out_enc_level3, scale_factor=2) # dim*2**2, lvl3 up-scaled to lvl2 (lvl3->lvl2)
|
| 691 |
+
out_enc_level3_1 = F.interpolate(out_enc_level3_2, scale_factor=2) # dim*2**2, lvl3 up-scaled to lvl1 (lvl3->lvl2->lvl1)
|
| 692 |
+
|
| 693 |
+
out_enc_level1 = self.clff_level1(out_enc_level1, out_enc_level2_1, out_enc_level3_1)
|
| 694 |
+
out_enc_level2 = self.clff_level2(out_enc_level2, out_enc_level1_2, out_enc_level3_2)
|
| 695 |
+
|
| 696 |
+
# Decoder
|
| 697 |
+
out_dec_level3_decomp1 = self.decoder_level3(out_enc_level3)
|
| 698 |
+
|
| 699 |
+
inp_dec_level2_decomp1 = self.up3_2(out_dec_level3_decomp1)
|
| 700 |
+
inp_dec_level2_decomp1 = self.reduce_chan_level2(torch.cat((inp_dec_level2_decomp1, out_enc_level2), dim=1))
|
| 701 |
+
out_dec_level2_decomp1 = self.decoder_level2(inp_dec_level2_decomp1)
|
| 702 |
+
|
| 703 |
+
inp_dec_level1_decomp1 = self.up2_1(out_dec_level2_decomp1)
|
| 704 |
+
inp_dec_level1_decomp1 = self.reduce_chan_level1(torch.cat((inp_dec_level1_decomp1, out_enc_level1), dim=1))
|
| 705 |
+
out_dec_level1_decomp1 = self.decoder_level1(inp_dec_level1_decomp1)
|
| 706 |
+
|
| 707 |
+
out_dec_level1_decomp1 = self.output(out_dec_level1_decomp1)
|
| 708 |
+
|
| 709 |
+
out_dec_level1 = out_dec_level1_decomp1 + inp_img
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
return out_dec_level1, out_dec_level1_decomp1, None
|
| 713 |
+
|
| 714 |
+
#"""
|
| 715 |
+
import time
|
| 716 |
+
start_time = time.time()
|
| 717 |
+
inp = torch.randn(1, 3, 512, 512).cuda()#.to(dtype=torch.float16)
|
| 718 |
+
model = XYScanNet().cuda()#.to(dtype=torch.float16)
|
| 719 |
+
out = model(inp)[0]
|
| 720 |
+
print(out.shape)
|
| 721 |
+
print("--- %s seconds ---" % (time.time() - start_time))
|
| 722 |
+
pytorch_total_params = sum(p.numel() for p in model.parameters())
|
| 723 |
+
print("--- {num} parameters ---".format(num = pytorch_total_params))
|
| 724 |
+
pytorch_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 725 |
+
print("--- {num} trainable parameters ---".format(num = pytorch_trainable_params))
|
| 726 |
+
gpu_memmem_usage_bytes = torch.cuda.max_memory_allocated()
|
| 727 |
+
print(gpu_memmem_usage_bytes / 1024 / 1024 / 1024) # 64: 0.61 128: 2.21 256: 8.56; 512: 33.45
|
| 728 |
+
#"""
|
| 729 |
+
"""
|
| 730 |
+
import torch
|
| 731 |
+
from ptflops import get_model_complexity_info
|
| 732 |
+
|
| 733 |
+
with torch.cuda.device(0):
|
| 734 |
+
net = model
|
| 735 |
+
macs, params = get_model_complexity_info(net, (3, 256, 256), as_strings=True,
|
| 736 |
+
print_per_layer_stat=True, verbose=True)
|
| 737 |
+
print('{:<30} {:<8}'.format('Computational complexity: ', macs)) # 31.97 GMac
|
| 738 |
+
print('{:<30} {:<8}'.format('Number of parameters: ', params)) # 8.37 M
|
| 739 |
+
"""
|
| 740 |
+
"""
|
| 741 |
+
import time
|
| 742 |
+
start_time = time.time()
|
| 743 |
+
inp = torch.randn(1, 128, 64, 64).cuda()#.to(dtype=torch.float16)
|
| 744 |
+
model = Strip_VSSB(dim=128, expansion_factor=1).cuda()#.to(dtype=torch.float16)
|
| 745 |
+
out = model(inp)
|
| 746 |
+
print(out.shape)
|
| 747 |
+
print("--- %s seconds ---" % (time.time() - start_time))
|
| 748 |
+
pytorch_total_params = sum(p.numel() for p in model.parameters())
|
| 749 |
+
print("--- {num} parameters ---".format(num = pytorch_total_params))
|
| 750 |
+
pytorch_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 751 |
+
print("--- {num} trainable parameters ---".format(num = pytorch_trainable_params))
|
| 752 |
+
gpu_memmem_usage_bytes = torch.cuda.max_memory_allocated()
|
| 753 |
+
print(gpu_memmem_usage_bytes / 1024 / 1024 / 1024) # 128: 0.16 256: 0.24 512: 0.65
|
| 754 |
+
"""
|
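Note: the model above returns a tuple (restored image, restored residual, None) and expects inputs normalized to [-0.5, 0.5]. Below is a minimal inference sketch for arbitrary-size images; the reflect padding to a multiple of 8 mirrors what the RealBlur/RWBI prediction scripts further down do. The helper name `deblur_image` and the example path are illustrative, not part of the repository.

```python
# Minimal inference sketch (assumes the XYScanNet class above and a CUDA device).
import cv2
import numpy as np
import torch
import torch.nn.functional as F

def deblur_image(model, img_bgr):
    img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    x = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
    x = x.unsqueeze(0).cuda()
    h, w = x.shape[2], x.shape[3]
    factor = 8
    padh = (factor - h % factor) % factor   # pad so H and W are multiples of 8
    padw = (factor - w % factor) % factor
    x = F.pad(x, (0, padw, 0, padh), 'reflect')
    with torch.no_grad():
        out, _, _ = model(x)                # (restored, residual, None)
    out = torch.clamp(out[:, :, :h, :w], -0.5, 0.5) + 0.5
    return out                              # (1, 3, h, w) tensor in [0, 1]

# Hypothetical usage:
# model = XYScanNet().cuda()
# restored = deblur_image(model, cv2.imread('examples/blur1.png'))
```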
out/Results.txt
ADDED
|
@@ -0,0 +1 @@
testing results are created in this folder
predict_GoPro_test_results.py
ADDED
|
@@ -0,0 +1,89 @@
from __future__ import print_function
import numpy as np
import torch
import cv2
import yaml
import os
from torch.autograd import Variable
from models.networks import get_generator
import torchvision
import time
import argparse
import torch.nn.functional as F

def get_args():
    parser = argparse.ArgumentParser('Test an image')
    parser.add_argument('--job_name', default='xyscannet',
                        type=str, help='current job s name')
    return parser.parse_args()

def print_max_gpu_usage():
    """Prints the maximum GPU memory usage in GB."""
    max_memory = torch.cuda.max_memory_allocated()
    max_memory_in_gb = max_memory / (1024 ** 3)  # Convert bytes to GB
    print(f"Maximum GPU memory usage during test: {max_memory_in_gb:.2f} GB")

if __name__ == '__main__':
    # optionally reset gpu
    #torch.cuda.reset_max_memory_allocated()
    args = get_args()
    #with open(os.path.join('config/', args.job_name, 'config_stage2.yaml'), 'r') as cfg:  # change the CFG name to test different models: pretrained, gopro, refined, stage1, stage2
    #    config = yaml.safe_load(cfg)
    with open(os.path.join('config/', args.job_name, 'config_stage2.yaml'), 'r') as cfg:  # change the CFG name to test different models: pretrained, gopro, refined, stage1, stage2
        config = yaml.safe_load(cfg)
    blur_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/datasets/GOPRO_/test/testA'
    out_path = os.path.join('results', args.job_name, 'images')
    weights_path = os.path.join('results', args.job_name, 'models', 'best_{}.pth'.format(config['experiment_desc']))  # change the model name to test different phases: final/best
    if not os.path.isdir(out_path):
        os.mkdir(out_path)
    model = get_generator(config['model'])
    model.load_state_dict(torch.load(weights_path))
    model = model.cuda()
    #model.eval()

    test_time = 0
    iteration = 0
    total_image_number = 1111

    # warm-up
    warm_up = 0
    print('Hardware warm-up')
    for file in os.listdir(blur_path):
        for img_name in os.listdir(blur_path + '/' + file):
            warm_up += 1
            img = cv2.imread(blur_path + '/' + file + '/' + img_name)
            img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
            with torch.no_grad():
                img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
                result_image, decomp1, decomp2 = model(img_tensor)
                #result_image = model(img_tensor)
            if warm_up == 20:
                break
        break

    for file in os.listdir(blur_path):
        if not os.path.isdir(out_path + '/' + file):
            os.mkdir(out_path + '/' + file)
        for img_name in os.listdir(blur_path + '/' + file):
            img = cv2.imread(blur_path + '/' + file + '/' + img_name)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
            with torch.no_grad():
                iteration += 1
                img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()

                start = time.time()
                result_image, decomp1, decomp2 = model(img_tensor)
                #result_image = model(img_tensor)
                stop = time.time()
                print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start)))
                test_time += stop - start
                print('Average Runtime:{:.4f}'.format(test_time / float(iteration)))
                result_image = result_image + 0.5
                out_file_name = out_path + '/' + file + '/' + img_name
                # optionally save image
                torchvision.utils.save_image(result_image, out_file_name)

    # optionally print gpu usage
    #print_max_gpu_usage()
    #torch.cuda.reset_max_memory_allocated()
|
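The GoPro script above only writes restored frames; benchmark numbers are computed separately. As a rough, unofficial sanity check, one can compare a saved output against its sharp counterpart with PSNR. The sketch below assumes the sharp frames live in a `testB` folder mirroring `testA`; both paths and the frame name are placeholders.

```python
# Rough PSNR sanity check for a single saved result (not the official evaluation protocol).
import cv2
import numpy as np

def psnr(a, b, max_val=255.0):
    mse = np.mean((a.astype(np.float64) - b.astype(np.float64)) ** 2)
    return float('inf') if mse == 0 else 10 * np.log10(max_val ** 2 / mse)

# Placeholder paths: adjust to your own result / ground-truth layout.
restored = cv2.imread('results/xyscannet/images/GOPR0384_11_00/000001.png')
sharp = cv2.imread('datasets/GOPRO_/test/testB/GOPR0384_11_00/000001.png')
print('PSNR: {:.2f} dB'.format(psnr(restored, sharp)))
```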
predict_HIDE_test_results.py
ADDED
|
@@ -0,0 +1,69 @@
| 1 |
+
from __future__ import print_function
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import cv2
|
| 5 |
+
import yaml
|
| 6 |
+
import os
|
| 7 |
+
from torch.autograd import Variable
|
| 8 |
+
from models.networks import get_generator
|
| 9 |
+
import torchvision
|
| 10 |
+
import time
|
| 11 |
+
import argparse
|
| 12 |
+
|
| 13 |
+
def get_args():
|
| 14 |
+
parser = argparse.ArgumentParser('Test an image')
|
| 15 |
+
parser.add_argument('--job_name', default='xyscannet',
|
| 16 |
+
type=str, help='current job s name')
|
| 17 |
+
return parser.parse_args()
|
| 18 |
+
|
| 19 |
+
if __name__ == '__main__':
|
| 20 |
+
args = get_args()
|
| 21 |
+
with open(os.path.join('config/', args.job_name, 'config_stage2.yaml')) as cfg: # change the yaml file to config_pretrained if ablation
|
| 22 |
+
#with open(os.path.join('config/', args.job_name, 'config_stage2.yaml')) as cfg: # change the yaml file to config_pretrained if ablation
|
| 23 |
+
config = yaml.safe_load(cfg)
|
| 24 |
+
blur_path = '/scratch/user/hanzhou1996/datasets/deblur/HIDE/test/testA/'
|
| 25 |
+
out_path = os.path.join('results', args.job_name, 'images_hide')
|
| 26 |
+
weights_path = os.path.join('results', args.job_name, 'models', 'best_XYScanNet_stage2.pth') # change the model name to test different phases: final/best
|
| 27 |
+
if not os.path.isdir(out_path):
|
| 28 |
+
os.mkdir(out_path)
|
| 29 |
+
model = get_generator(config['model'])
|
| 30 |
+
model.load_state_dict(torch.load(weights_path))
|
| 31 |
+
model = model.cuda()
|
| 32 |
+
test_time = 0
|
| 33 |
+
iteration = 0
|
| 34 |
+
total_image_number = 2025
|
| 35 |
+
|
| 36 |
+
# warm up
|
| 37 |
+
warm_up = 0
|
| 38 |
+
print('Hardware warm-up')
|
| 39 |
+
for img_name in os.listdir(blur_path):
|
| 40 |
+
warm_up += 1
|
| 41 |
+
img = cv2.imread(blur_path + '/' + img_name)
|
| 42 |
+
img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
|
| 43 |
+
with torch.no_grad():
|
| 44 |
+
img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
|
| 45 |
+
result_image, decomp1, decomp2 = model(img_tensor)
|
| 46 |
+
#result_image = model(img_tensor)
|
| 47 |
+
if warm_up == 20:
|
| 48 |
+
break
|
| 49 |
+
break
|
| 50 |
+
|
| 51 |
+
for img_name in os.listdir(blur_path):
|
| 52 |
+
img = cv2.imread(blur_path + '/' + img_name)
|
| 53 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 54 |
+
img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
|
| 55 |
+
with torch.no_grad():
|
| 56 |
+
iteration += 1
|
| 57 |
+
img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
|
| 58 |
+
|
| 59 |
+
start = time.time()
|
| 60 |
+
result_image, decomp1, decomp2 = model(img_tensor)
|
| 61 |
+
#result_image = model(img_tensor)
|
| 62 |
+
stop = time.time()
|
| 63 |
+
|
| 64 |
+
print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start)))
|
| 65 |
+
test_time += stop - start
|
| 66 |
+
print('Average Runtime:{:.4f}'.format(test_time / float(iteration)))
|
| 67 |
+
result_image = result_image + 0.5
|
| 68 |
+
out_file_name = out_path + '/' + img_name
|
| 69 |
+
torchvision.utils.save_image(result_image, out_file_name)
|
predict_RWBI_test_results.py
ADDED
|
@@ -0,0 +1,88 @@
| 1 |
+
from __future__ import print_function
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import cv2
|
| 5 |
+
import yaml
|
| 6 |
+
import os
|
| 7 |
+
from torch.autograd import Variable
|
| 8 |
+
from models.networks import get_generator
|
| 9 |
+
import torchvision
|
| 10 |
+
import time
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
import argparse
|
| 13 |
+
|
| 14 |
+
def get_args():
|
| 15 |
+
parser = argparse.ArgumentParser('Test an image')
|
| 16 |
+
parser.add_argument('--job_name', default='xyscannet',
|
| 17 |
+
type=str, help='current job s name')
|
| 18 |
+
return parser.parse_args()
|
| 19 |
+
|
| 20 |
+
if __name__ == '__main__':
|
| 21 |
+
args = get_args()
|
| 22 |
+
with open(os.path.join('config/', args.job_name, 'config_pretrained.yaml')) as cfg: # change the yaml file to config_pretrained if ablation
|
| 23 |
+
#with open(os.path.join('config/', args.job_name, 'config_stage2.yaml')) as cfg: # change the yaml file to config_pretrained if ablation
|
| 24 |
+
config = yaml.safe_load(cfg)
|
| 25 |
+
blur_path = '/scratch/user/hanzhou1996/datasets/deblur/RWBI/test/testA/'
|
| 26 |
+
out_path = os.path.join('results', args.job_name, 'images_rwbi')
|
| 27 |
+
weights_path = os.path.join('results', args.job_name, 'models', 'best_XYScanNet_stage2.pth') # change the model name to test different phases: final/best
|
| 28 |
+
if not os.path.isdir(out_path):
|
| 29 |
+
os.mkdir(out_path)
|
| 30 |
+
model = get_generator(config['model'])
|
| 31 |
+
model.load_state_dict(torch.load(weights_path))
|
| 32 |
+
model = model.cuda()
|
| 33 |
+
test_time = 0
|
| 34 |
+
iteration = 0
|
| 35 |
+
total_image_number = 1000
|
| 36 |
+
|
| 37 |
+
# warm up
|
| 38 |
+
warm_up = 0
|
| 39 |
+
print('Hardware warm-up')
|
| 40 |
+
for img_name in os.listdir(blur_path):
|
| 41 |
+
warm_up += 1
|
| 42 |
+
img = cv2.imread(blur_path + '/' + img_name)
|
| 43 |
+
img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
|
| 44 |
+
with torch.no_grad():
|
| 45 |
+
img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
|
| 46 |
+
factor = 8
|
| 47 |
+
h, w = img_tensor.shape[2], img_tensor.shape[3]
|
| 48 |
+
H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
|
| 49 |
+
padh = H - h if h % factor != 0 else 0
|
| 50 |
+
padw = W - w if w % factor != 0 else 0
|
| 51 |
+
img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect')
|
| 52 |
+
H, W = img_tensor.shape[2], img_tensor.shape[3]
|
| 53 |
+
|
| 54 |
+
result_image, decomp1, decomp2 = model(img_tensor)
|
| 55 |
+
if warm_up == 20:
|
| 56 |
+
break
|
| 57 |
+
|
| 58 |
+
for file in os.listdir(blur_path):
|
| 59 |
+
if not os.path.isdir(out_path):
|
| 60 |
+
os.mkdir(out_path)
|
| 61 |
+
img = cv2.imread(blur_path + '/' + file)
|
| 62 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 63 |
+
img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
|
| 64 |
+
with torch.no_grad():
|
| 65 |
+
iteration += 1
|
| 66 |
+
img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
|
| 67 |
+
|
| 68 |
+
factor = 8
|
| 69 |
+
h, w = img_tensor.shape[2], img_tensor.shape[3]
|
| 70 |
+
H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
|
| 71 |
+
padh = H - h if h % factor != 0 else 0
|
| 72 |
+
padw = W - w if w % factor != 0 else 0
|
| 73 |
+
img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect')
|
| 74 |
+
H, W = img_tensor.shape[2], img_tensor.shape[3]
|
| 75 |
+
|
| 76 |
+
#with torch.autocast(device_type='cuda', dtype=torch.float16):
|
| 77 |
+
start = time.time()
|
| 78 |
+
result_image, decomp1, decomp2 = model(img_tensor)
|
| 79 |
+
stop = time.time()
|
| 80 |
+
|
| 81 |
+
result_image = result_image[:, :, :h, :w]
|
| 82 |
+
|
| 83 |
+
print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start)))
|
| 84 |
+
test_time += stop - start
|
| 85 |
+
print('Average Runtime:{:.4f}'.format(test_time / float(iteration)))
|
| 86 |
+
result_image = result_image + 0.5
|
| 87 |
+
out_file_name = out_path + '/' + file
|
| 88 |
+
torchvision.utils.save_image(result_image, out_file_name)
|
predict_RealBlur_J_test_results.py
ADDED
|
@@ -0,0 +1,97 @@
| 1 |
+
from __future__ import print_function
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import cv2
|
| 5 |
+
import yaml
|
| 6 |
+
import os
|
| 7 |
+
from torch.autograd import Variable
|
| 8 |
+
from models.networks import get_generator
|
| 9 |
+
import torchvision
|
| 10 |
+
import time
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
import argparse
|
| 13 |
+
|
| 14 |
+
def get_args():
|
| 15 |
+
parser = argparse.ArgumentParser('Test an image')
|
| 16 |
+
parser.add_argument('--job_name', default='fsformer_without_fs',
|
| 17 |
+
type=str, help='current job s name')
|
| 18 |
+
return parser.parse_args()
|
| 19 |
+
|
| 20 |
+
if __name__ == '__main__':
|
| 21 |
+
args = get_args()
|
| 22 |
+
with open(os.path.join('config/', args.job_name, 'config_stage2.yaml')) as cfg:
|
| 23 |
+
#with open(os.path.join('config/', args.job_name, 'config_pretrained.yaml')) as cfg:
|
| 24 |
+
config = yaml.safe_load(cfg)
|
| 25 |
+
blur_path = '/scratch/user/hanzhou1996/datasets/deblur/RealBlur_J/test/testA'
|
| 26 |
+
out_path = os.path.join('results', args.job_name, 'images_realj')
|
| 27 |
+
weights_path = os.path.join('results', args.job_name, 'models', 'final_XYScanNet_stage2.pth') # change the model name to test different phases: final/best final_StripMamba_pretrained.pth
|
| 28 |
+
if not os.path.isdir(out_path):
|
| 29 |
+
os.mkdir(out_path)
|
| 30 |
+
model = get_generator(config['model'])
|
| 31 |
+
model.load_state_dict(torch.load(weights_path))
|
| 32 |
+
model = model.cuda()
|
| 33 |
+
test_time = 0
|
| 34 |
+
iteration = 0
|
| 35 |
+
total_image_number = 980
|
| 36 |
+
|
| 37 |
+
# warm up
|
| 38 |
+
warm_up = 0
|
| 39 |
+
print('Hardware warm-up')
|
| 40 |
+
for file in os.listdir(blur_path):
|
| 41 |
+
#if not os.path.isdir(out_path + '/' + file):
|
| 42 |
+
# os.mkdir(out_path + '/' + file)
|
| 43 |
+
img_name = file
|
| 44 |
+
# for img_name in os.listdir(blur_path + '/' + file):
|
| 45 |
+
warm_up += 1
|
| 46 |
+
img = cv2.imread(blur_path + '/' + file)
|
| 47 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 48 |
+
img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
|
| 49 |
+
with torch.no_grad():
|
| 50 |
+
img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
|
| 51 |
+
factor = 8
|
| 52 |
+
h, w = img_tensor.shape[2], img_tensor.shape[3]
|
| 53 |
+
H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
|
| 54 |
+
padh = H - h if h % factor != 0 else 0
|
| 55 |
+
padw = W - w if w % factor != 0 else 0
|
| 56 |
+
img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect')
|
| 57 |
+
result_image, decomp1, decomp2 = model(img_tensor)
|
| 58 |
+
#result_image = model(img_tensor)
|
| 59 |
+
if warm_up == 20:
|
| 60 |
+
break
|
| 61 |
+
break
|
| 62 |
+
|
| 63 |
+
for file in os.listdir(blur_path):
|
| 64 |
+
#if not os.path.isdir(out_path + '/' + file):
|
| 65 |
+
# os.mkdir(out_path + '/' + file)
|
| 66 |
+
img_name = file
|
| 67 |
+
# for img_name in os.listdir(blur_path + '/' + file):
|
| 68 |
+
img = cv2.imread(blur_path + '/' + file)
|
| 69 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 70 |
+
img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
|
| 71 |
+
with torch.no_grad():
|
| 72 |
+
iteration += 1
|
| 73 |
+
img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
|
| 74 |
+
|
| 75 |
+
factor = 8
|
| 76 |
+
h, w = img_tensor.shape[2], img_tensor.shape[3]
|
| 77 |
+
H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
|
| 78 |
+
padh = H - h if h % factor != 0 else 0
|
| 79 |
+
padw = W - w if w % factor != 0 else 0
|
| 80 |
+
img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect')
|
| 81 |
+
H, W = img_tensor.shape[2], img_tensor.shape[3]
|
| 82 |
+
|
| 83 |
+
start = time.time()
|
| 84 |
+
_output, decomp1, decomp2 = model(img_tensor)
|
| 85 |
+
#_output = model(img_tensor)
|
| 86 |
+
stop = time.time()
|
| 87 |
+
|
| 88 |
+
result_image = _output[:, :, :h, :w]
|
| 89 |
+
result_image = torch.clamp(result_image, -0.5, 0.5)
|
| 90 |
+
result_image = result_image + 0.5
|
| 91 |
+
|
| 92 |
+
test_time += stop - start
|
| 93 |
+
print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start)))
|
| 94 |
+
print('Average Runtime:{:.4f}'.format(test_time / float(iteration)))
|
| 95 |
+
out_file_name = out_path + '/' + img_name
|
| 96 |
+
torchvision.utils.save_image(result_image, out_file_name)
|
| 97 |
+
|
predict_RealBlur_R_test_results.py
ADDED
|
@@ -0,0 +1,96 @@
| 1 |
+
from __future__ import print_function
|
| 2 |
+
import argparse
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import cv2
|
| 6 |
+
import yaml
|
| 7 |
+
import os
|
| 8 |
+
from torch.autograd import Variable
|
| 9 |
+
from models.networks import get_generator
|
| 10 |
+
import torchvision
|
| 11 |
+
import time
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
def get_args():
|
| 15 |
+
parser = argparse.ArgumentParser('Test an image')
|
| 16 |
+
parser.add_argument('--job_name', default='xyscannet',
|
| 17 |
+
type=str, help='current job s name')
|
| 18 |
+
return parser.parse_args()
|
| 19 |
+
|
| 20 |
+
if __name__ == '__main__':
|
| 21 |
+
args = get_args()
|
| 22 |
+
with open(os.path.join('config/', args.job_name, 'config_stage2.yaml')) as cfg:
|
| 23 |
+
config = yaml.safe_load(cfg)
|
| 24 |
+
blur_path = '/scratch/user/hanzhou1996/datasets/deblur/RealBlur_R/test/testA'
|
| 25 |
+
out_path = os.path.join('results', args.job_name, 'images_realr')
|
| 26 |
+
weights_path = os.path.join('results', args.job_name, 'models', 'final_XYScanNet_stage2.pth') # change the model name to test different phases: final/best
|
| 27 |
+
if not os.path.isdir(out_path):
|
| 28 |
+
os.mkdir(out_path)
|
| 29 |
+
model = get_generator(config['model'])
|
| 30 |
+
model.load_state_dict(torch.load(weights_path))
|
| 31 |
+
model = model.cuda()
|
| 32 |
+
test_time = 0
|
| 33 |
+
iteration = 0
|
| 34 |
+
total_image_number = 980
|
| 35 |
+
|
| 36 |
+
# warm up
|
| 37 |
+
warm_up = 0
|
| 38 |
+
print('Hardware warm-up')
|
| 39 |
+
for file in os.listdir(blur_path):
|
| 40 |
+
#if not os.path.isdir(out_path + '/' + file):
|
| 41 |
+
# os.mkdir(out_path + '/' + file)
|
| 42 |
+
#for img_name in os.listdir(blur_path + '/' + file):
|
| 43 |
+
img_name = file
|
| 44 |
+
warm_up += 1
|
| 45 |
+
img = cv2.imread(blur_path + '/' + file)
|
| 46 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 47 |
+
img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
|
| 48 |
+
with torch.no_grad():
|
| 49 |
+
img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
|
| 50 |
+
factor = 8
|
| 51 |
+
h, w = img_tensor.shape[2], img_tensor.shape[3]
|
| 52 |
+
H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
|
| 53 |
+
padh = H - h if h % factor != 0 else 0
|
| 54 |
+
padw = W - w if w % factor != 0 else 0
|
| 55 |
+
img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect')
|
| 56 |
+
result_image, decomp1, decomp2 = model(img_tensor)
|
| 57 |
+
#result_image = model(img_tensor)
|
| 58 |
+
if warm_up == 20:
|
| 59 |
+
break
|
| 60 |
+
break
|
| 61 |
+
|
| 62 |
+
for file in os.listdir(blur_path):
|
| 63 |
+
#if not os.path.isdir(out_path + '/' + file):
|
| 64 |
+
# os.mkdir(out_path + '/' + file)
|
| 65 |
+
#for img_name in os.listdir(blur_path + '/' + file):
|
| 66 |
+
img_name = file
|
| 67 |
+
img = cv2.imread(blur_path + '/' + file)
|
| 68 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 69 |
+
img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
|
| 70 |
+
with torch.no_grad():
|
| 71 |
+
iteration += 1
|
| 72 |
+
img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
|
| 73 |
+
|
| 74 |
+
factor = 8
|
| 75 |
+
h, w = img_tensor.shape[2], img_tensor.shape[3]
|
| 76 |
+
H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
|
| 77 |
+
padh = H - h if h % factor != 0 else 0
|
| 78 |
+
padw = W - w if w % factor != 0 else 0
|
| 79 |
+
img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect')
|
| 80 |
+
H, W = img_tensor.shape[2], img_tensor.shape[3]
|
| 81 |
+
|
| 82 |
+
start = time.time()
|
| 83 |
+
_output, decomp1, decomp2 = model(img_tensor)
|
| 84 |
+
#_output = model(img_tensor)
|
| 85 |
+
stop = time.time()
|
| 86 |
+
|
| 87 |
+
result_image = _output[:, :, :h, :w]
|
| 88 |
+
result_image = torch.clamp(result_image, -0.5, 0.5)
|
| 89 |
+
result_image = result_image + 0.5
|
| 90 |
+
|
| 91 |
+
test_time += stop - start
|
| 92 |
+
print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start)))
|
| 93 |
+
print('Average Runtime:{:.4f}'.format(test_time / float(iteration)))
|
| 94 |
+
out_file_name = out_path + '/' + img_name
|
| 95 |
+
torchvision.utils.save_image(result_image, out_file_name)
|
| 96 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
gradio==4.44.1
spaces
torch==2.1.2
torchvision==0.16.2
transformers==4.46.3
einops==0.8.1
PyYAML==6.0.2
opencv-python-headless==4.10.0.84
numpy==1.26.4
pillow==10.4.0
mamba-ssm==2.2.2
results/xyscannetp_gopro/models/best_XYScanNet_stage2.pth
ADDED
|
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:da56c2c8ccb0c7cfc86c81431e9b7c2681c2109ca056d56d9454ba4aeb6c07e0
size 254328477
schedulers.py
ADDED
|
@@ -0,0 +1,59 @@
import math

from torch.optim import lr_scheduler


class WarmRestart(lr_scheduler.CosineAnnealingLR):
    """This class implements Stochastic Gradient Descent with Warm Restarts (SGDR): https://arxiv.org/abs/1608.03983.

    Set the learning rate of each parameter group using a cosine annealing schedule. When last_epoch=-1, sets initial lr as lr.
    This can't support scheduler.step(epoch); please keep epoch=None.
    """

    def __init__(self, optimizer, T_max=30, T_mult=1, eta_min=0, last_epoch=-1):
        """implements SGDR

        Parameters:
        ----------
        T_max : int
            Maximum number of epochs.
        T_mult : int
            Multiplicative factor of T_max.
        eta_min : int
            Minimum learning rate. Default: 0.
        last_epoch : int
            The index of last epoch. Default: -1.
        """
        self.T_mult = T_mult
        super().__init__(optimizer, T_max, eta_min, last_epoch)

    def get_lr(self):
        if self.last_epoch == self.T_max:
            self.last_epoch = 0
            self.T_max *= self.T_mult
        return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2 for
                base_lr in self.base_lrs]


class LinearDecay(lr_scheduler._LRScheduler):
    """This class implements LinearDecay.
    """

    def __init__(self, optimizer, num_epochs, start_epoch=0, min_lr=0, last_epoch=-1):
        """implements LinearDecay

        Parameters:
        ----------

        """
        self.num_epochs = num_epochs
        self.start_epoch = start_epoch
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.start_epoch:
            return self.base_lrs
        return [base_lr - ((base_lr - self.min_lr) / self.num_epochs) * (self.last_epoch - self.start_epoch) for
                base_lr in self.base_lrs]
|
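For context, a minimal usage sketch for the two schedulers above; the training scripts in this repo actually use torch's CosineAnnealingLR, and the hyperparameter values below are illustrative only.

```python
# Illustrative usage of WarmRestart / LinearDecay; values are placeholders.
from torch import nn, optim
from schedulers import WarmRestart, LinearDecay  # assumes this file is importable as 'schedulers'

net = nn.Conv2d(3, 3, 3)
optimizer = optim.Adam(net.parameters(), lr=1e-4)

scheduler = WarmRestart(optimizer, T_max=30, T_mult=1, eta_min=1e-7)   # cosine schedule restarting every 30 epochs
# scheduler = LinearDecay(optimizer, num_epochs=1000, start_epoch=500, min_lr=1e-7)  # linear ramp-down after epoch 500

for epoch in range(90):
    # ... one training epoch ...
    scheduler.step()   # keep epoch=None, as the WarmRestart docstring notes
```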
train_XYScanNet_stage1.py
ADDED
|
@@ -0,0 +1,182 @@
| 1 |
+
import logging
|
| 2 |
+
from functools import partial
|
| 3 |
+
import os
|
| 4 |
+
import cv2
|
| 5 |
+
import torch
|
| 6 |
+
import torch.optim as optim
|
| 7 |
+
import tqdm
|
| 8 |
+
import yaml
|
| 9 |
+
from joblib import cpu_count
|
| 10 |
+
from torch.utils.data import DataLoader
|
| 11 |
+
import random
|
| 12 |
+
from dataset import PairedDataset
|
| 13 |
+
from metric_counter import MetricCounter
|
| 14 |
+
|
| 15 |
+
from models.losses import get_loss
|
| 16 |
+
from models.models import get_model
|
| 17 |
+
from models.networks import get_nets
|
| 18 |
+
from util import util
|
| 19 |
+
import numpy as np
|
| 20 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
| 21 |
+
cv2.setNumThreads(0)
|
| 22 |
+
|
| 23 |
+
import argparse
|
| 24 |
+
parser = argparse.ArgumentParser(description='Image motion deblurring evaluation on GoPro/HIDE')
|
| 25 |
+
parser.add_argument('--job_name', default='xyscannet',
|
| 26 |
+
type=str, help='current job s name')
|
| 27 |
+
args = parser.parse_args()
|
| 28 |
+
|
| 29 |
+
class Trainer:
|
| 30 |
+
def __init__(self, config, train: DataLoader, val: DataLoader):
|
| 31 |
+
self.config = config
|
| 32 |
+
self.train_dataset = train
|
| 33 |
+
self.val_dataset = val
|
| 34 |
+
self.metric_counter = MetricCounter(config['experiment_desc'])
|
| 35 |
+
|
| 36 |
+
def train(self):
|
| 37 |
+
self._init_params()
|
| 38 |
+
start_epoch = 0
|
| 39 |
+
print("The current job is: ", args.job_name)
|
| 40 |
+
model_dir = os.path.join('results/', args.job_name, 'models')
|
| 41 |
+
util.mkdir(model_dir)
|
| 42 |
+
if os.path.exists(os.path.join(model_dir, 'last_XYScanNet_stage1.pth')):
|
| 43 |
+
print('resume learning')
|
| 44 |
+
training_state = (torch.load(os.path.join(model_dir, 'last_XYScanNet_stage1.pth')))
|
| 45 |
+
start_epoch = training_state['epoch'] + 1
|
| 46 |
+
new_weight = self.netG.state_dict()
|
| 47 |
+
new_weight.update(training_state['model_state'])
|
| 48 |
+
self.netG.load_state_dict(new_weight)
|
| 49 |
+
new_optimizer = self.optimizer_G.state_dict()
|
| 50 |
+
new_optimizer.update(training_state['optimizer_state'])
|
| 51 |
+
self.optimizer_G.load_state_dict(new_optimizer)
|
| 52 |
+
new_scheduler = self.scheduler_G.state_dict()
|
| 53 |
+
new_scheduler.update(training_state['scheduler_state'])
|
| 54 |
+
self.scheduler_G.load_state_dict(new_scheduler)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
for epoch in range(start_epoch, config['num_epochs']):
|
| 58 |
+
self._run_epoch(epoch)
|
| 59 |
+
if epoch % 30 == 0 or epoch == (config['num_epochs']-1):
|
| 60 |
+
self._validate(epoch)
|
| 61 |
+
self.scheduler_G.step()
|
| 62 |
+
|
| 63 |
+
scheduler_state = self.scheduler_G.state_dict()
|
| 64 |
+
training_state = {'epoch': epoch, 'model_state': self.netG.state_dict(),
|
| 65 |
+
'scheduler_state': scheduler_state, 'optimizer_state': self.optimizer_G.state_dict()}
|
| 66 |
+
if self.metric_counter.update_best_model():
|
| 67 |
+
torch.save(training_state['model_state'],
|
| 68 |
+
os.path.join(model_dir, 'best_{}.pth'.format(self.config['experiment_desc'])))
|
| 69 |
+
|
| 70 |
+
if epoch % 300 == 0:
|
| 71 |
+
torch.save(training_state,
|
| 72 |
+
os.path.join(model_dir, 'last_{}_{}.pth'.format(self.config['experiment_desc'], epoch)))
|
| 73 |
+
|
| 74 |
+
if epoch == (config['num_epochs']-1):
|
| 75 |
+
torch.save(training_state['model_state'],
|
| 76 |
+
os.path.join(model_dir, 'final_{}.pth'.format(self.config['experiment_desc'])))
|
| 77 |
+
|
| 78 |
+
torch.save(training_state,
|
| 79 |
+
os.path.join(model_dir, 'last_{}.pth'.format(self.config['experiment_desc'])))
|
| 80 |
+
|
| 81 |
+
logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s" % (
|
| 82 |
+
self.config['experiment_desc'], epoch, self.metric_counter.loss_message()))
|
| 83 |
+
|
| 84 |
+
def _run_epoch(self, epoch):
|
| 85 |
+
self.metric_counter.clear()
|
| 86 |
+
for param_group in self.optimizer_G.param_groups:
|
| 87 |
+
lr = param_group['lr']
|
| 88 |
+
|
| 89 |
+
epoch_size = config.get('train_batches_per_epoch') or len(self.train_dataset)
|
| 90 |
+
tq = tqdm.tqdm(self.train_dataset)
|
| 91 |
+
tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
|
| 92 |
+
i = 0
|
| 93 |
+
for data in tq:
|
| 94 |
+
inputs, targets = self.model.get_input(data)
|
| 95 |
+
outputs, decomp1, decomp2 = self.netG(inputs)
|
| 96 |
+
self.optimizer_G.zero_grad()
|
| 97 |
+
loss_G = self.criterionG(outputs, targets, inputs)
|
| 98 |
+
loss_G.backward()
|
| 99 |
+
self.optimizer_G.step()
|
| 100 |
+
self.metric_counter.add_losses(loss_G.item())
|
| 101 |
+
curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
|
| 102 |
+
self.metric_counter.add_metrics(curr_psnr, curr_ssim)
|
| 103 |
+
tq.set_postfix(loss=self.metric_counter.loss_message())
|
| 104 |
+
if not i:
|
| 105 |
+
self.metric_counter.add_image(img_for_vis, tag='train')
|
| 106 |
+
i += 1
|
| 107 |
+
if i > len(self.train_dataset):
|
| 108 |
+
break
|
| 109 |
+
tq.close()
|
| 110 |
+
self.metric_counter.write_to_tensorboard(epoch)
|
| 111 |
+
|
| 112 |
+
def _validate(self, epoch):
|
| 113 |
+
self.metric_counter.clear()
|
| 114 |
+
epoch_size = config.get('val_batches_per_epoch') or len(self.val_dataset)
|
| 115 |
+
tq = tqdm.tqdm(self.val_dataset)
|
| 116 |
+
tq.set_description('Validation')
|
| 117 |
+
i = 0
|
| 118 |
+
for data in tq:
|
| 119 |
+
with torch.no_grad():
|
| 120 |
+
inputs, targets = self.model.get_input(data)
|
| 121 |
+
outputs, decomp1, decomp2 = self.netG(inputs)
|
| 122 |
+
loss_G = self.criterionG(outputs, targets, inputs)
|
| 123 |
+
self.metric_counter.add_losses(loss_G.item())
|
| 124 |
+
curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
|
| 125 |
+
self.metric_counter.add_metrics(curr_psnr, curr_ssim)
|
| 126 |
+
if not i:
|
| 127 |
+
self.metric_counter.add_image(img_for_vis, tag='val')
|
| 128 |
+
i += 1
|
| 129 |
+
if i > len(self.train_dataset):
|
| 130 |
+
break
|
| 131 |
+
tq.close()
|
| 132 |
+
self.metric_counter.write_to_tensorboard(epoch, validation=True)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _get_optim(self, params):
|
| 136 |
+
if self.config['optimizer']['name'] == 'adam':
|
| 137 |
+
optimizer = optim.Adam(params, lr=self.config['optimizer']['lr'])
|
| 138 |
+
elif self.config['optimizer']['name'] == 'adamw':
|
| 139 |
+
optimizer = optim.AdamW(params, lr=0.001, weight_decay=0.001, betas=(0.9,0.9))
|
| 140 |
+
else:
|
| 141 |
+
raise ValueError("Optimizer [%s] not recognized." % self.config['optimizer']['name'])
|
| 142 |
+
return optimizer
|
| 143 |
+
|
| 144 |
+
def _get_scheduler(self, optimizer):
|
| 145 |
+
if self.config['scheduler']['name'] == 'cosine':
|
| 146 |
+
scheduler = CosineAnnealingLR(optimizer, T_max=self.config['num_epochs'], eta_min=self.config['scheduler']['min_lr'])
|
| 147 |
+
else:
|
| 148 |
+
raise ValueError("Scheduler [%s] not recognized." % self.config['scheduler']['name'])
|
| 149 |
+
return scheduler
|
| 150 |
+
|
| 151 |
+
def _init_params(self):
|
| 152 |
+
self.criterionG = get_loss(self.config['model'])
|
| 153 |
+
self.netG = get_nets(self.config['model'])
|
| 154 |
+
self.netG.cuda()
|
| 155 |
+
self.model = get_model(self.config['model'])
|
| 156 |
+
self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters()))
|
| 157 |
+
self.scheduler_G = self._get_scheduler(self.optimizer_G)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
if __name__ == '__main__':
|
| 161 |
+
with open(os.path.join('config/', args.job_name, 'config_stage1.yaml'), 'r') as f:
|
| 162 |
+
config = yaml.safe_load(f)
|
| 163 |
+
|
| 164 |
+
# setup
|
| 165 |
+
torch.backends.cudnn.enabled = True
|
| 166 |
+
torch.backends.cudnn.benchmark = True
|
| 167 |
+
|
| 168 |
+
# set random seed
|
| 169 |
+
seed = 666
|
| 170 |
+
torch.manual_seed(seed)
|
| 171 |
+
torch.cuda.manual_seed(seed)
|
| 172 |
+
random.seed(seed)
|
| 173 |
+
np.random.seed(seed)
|
| 174 |
+
|
| 175 |
+
batch_size = config.pop('batch_size')
|
| 176 |
+
get_dataloader = partial(DataLoader, batch_size=batch_size, num_workers=cpu_count(), shuffle=True, drop_last=False)
|
| 177 |
+
|
| 178 |
+
datasets = map(config.pop, ('train', 'val'))
|
| 179 |
+
datasets = map(PairedDataset.from_config, datasets)
|
| 180 |
+
train, val = map(get_dataloader, datasets)
|
| 181 |
+
trainer = Trainer(config, train=train, val=val)
|
| 182 |
+
trainer.train()
|
train_XYScanNet_stage2.py
ADDED
|
@@ -0,0 +1,182 @@
| 1 |
+
import logging
|
| 2 |
+
from functools import partial
|
| 3 |
+
import os
|
| 4 |
+
import cv2
|
| 5 |
+
import torch
|
| 6 |
+
import torch.optim as optim
|
| 7 |
+
import tqdm
|
| 8 |
+
import yaml
|
| 9 |
+
from joblib import cpu_count
|
| 10 |
+
from torch.utils.data import DataLoader
|
| 11 |
+
import random
|
| 12 |
+
from dataset import PairedDataset
|
| 13 |
+
from metric_counter import MetricCounter
|
| 14 |
+
from models.losses import get_loss
|
| 15 |
+
from models.models import get_model
|
| 16 |
+
from models.networks import get_nets
|
| 17 |
+
import numpy as np
|
| 18 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
| 19 |
+
cv2.setNumThreads(0)
|
| 20 |
+
|
| 21 |
+
import argparse
|
| 22 |
+
parser = argparse.ArgumentParser(description='Image motion deblurring evaluation on GoPro/HIDE')
|
| 23 |
+
parser.add_argument('--job_name', default='xyscannet',
|
| 24 |
+
type=str, help='current job s name')
|
| 25 |
+
args = parser.parse_args()
|
| 26 |
+
|
| 27 |
+
class Trainer:
|
| 28 |
+
def __init__(self, config, train: DataLoader, val: DataLoader):
|
| 29 |
+
self.config = config
|
| 30 |
+
self.train_dataset = train
|
| 31 |
+
self.val_dataset = val
|
| 32 |
+
self.metric_counter = MetricCounter(config['experiment_desc'])
|
| 33 |
+
|
| 34 |
+
def train(self):
|
| 35 |
+
self._init_params()
|
| 36 |
+
start_epoch = 0
|
| 37 |
+
print("The current job is: ", args.job_name)
|
| 38 |
+
model_dir = os.path.join('results/', args.job_name, 'models')
|
| 39 |
+
if os.path.exists(os.path.join(model_dir, 'last_XYScanNet_stage2.pth')):
|
| 40 |
+
print('resume learning')
|
| 41 |
+
training_state = (torch.load(os.path.join(model_dir, 'last_XYScanNet_stage2.pth')))
|
| 42 |
+
start_epoch = training_state['epoch'] + 1
|
| 43 |
+
new_weight = self.netG.state_dict()
|
| 44 |
+
new_weight.update(training_state['model_state'])
|
| 45 |
+
self.netG.load_state_dict(new_weight)
|
| 46 |
+
new_optimizer = self.optimizer_G.state_dict()
|
| 47 |
+
new_optimizer.update(training_state['optimizer_state'])
|
| 48 |
+
self.optimizer_G.load_state_dict(new_optimizer)
|
| 49 |
+
new_scheduler = self.scheduler_G.state_dict()
|
| 50 |
+
new_scheduler.update(training_state['scheduler_state'])
|
| 51 |
+
self.scheduler_G.load_state_dict(new_scheduler)
|
| 52 |
+
else:
|
| 53 |
+
print('load_weights_stage1')
|
| 54 |
+
training_state = (torch.load(os.path.join(model_dir, 'final_XYScanNet_stage1.pth')))
|
| 55 |
+
new_weight = self.netG.state_dict()
|
| 56 |
+
new_weight.update(training_state)
|
| 57 |
+
self.netG.load_state_dict(new_weight)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
for epoch in range(start_epoch, config['num_epochs']):
|
| 61 |
+
self._run_epoch(epoch)
|
| 62 |
+
if epoch % 30 == 0 or epoch == (config['num_epochs']-1):
|
| 63 |
+
self._validate(epoch)
|
| 64 |
+
self.scheduler_G.step()
|
| 65 |
+
|
| 66 |
+
scheduler_state = self.scheduler_G.state_dict()
|
| 67 |
+
training_state = {'epoch': epoch, 'model_state': self.netG.state_dict(),
|
| 68 |
+
'scheduler_state': scheduler_state, 'optimizer_state': self.optimizer_G.state_dict()}
|
| 69 |
+
if self.metric_counter.update_best_model():
|
| 70 |
+
torch.save(training_state['model_state'],
|
| 71 |
+
os.path.join(model_dir, 'best_{}.pth'.format(self.config['experiment_desc'])))
|
| 72 |
+
if epoch % 200 == 0:
|
| 73 |
+
torch.save(training_state,
|
| 74 |
+
os.path.join(model_dir, 'last_{}_{}.pth'.format(self.config['experiment_desc'], epoch)))
|
| 75 |
+
|
| 76 |
+
if epoch == (config['num_epochs']-1):
|
| 77 |
+
torch.save(training_state['model_state'],
|
| 78 |
+
os.path.join(model_dir, 'final_{}.pth'.format(self.config['experiment_desc'])))
|
| 79 |
+
|
| 80 |
+
torch.save(training_state,
|
| 81 |
+
os.path.join(model_dir, 'last_{}.pth'.format(self.config['experiment_desc'])))
|
| 82 |
+
logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s" % (
|
| 83 |
+
self.config['experiment_desc'], epoch, self.metric_counter.loss_message()))
|
| 84 |
+
|
| 85 |
+
def _run_epoch(self, epoch):
|
| 86 |
+
self.metric_counter.clear()
|
| 87 |
+
for param_group in self.optimizer_G.param_groups:
|
| 88 |
+
lr = param_group['lr']
|
| 89 |
+
|
| 90 |
+
epoch_size = config.get('train_batches_per_epoch') or len(self.train_dataset)
|
| 91 |
+
tq = tqdm.tqdm(self.train_dataset)
|
| 92 |
+
tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
|
| 93 |
+
i = 0
|
| 94 |
+
for data in tq:
|
| 95 |
+
inputs, targets = self.model.get_input(data)
|
| 96 |
+
outputs, decomp1, decomp2 = self.netG(inputs)
|
| 97 |
+
#outputs = self.netG(inputs)
|
| 98 |
+
self.optimizer_G.zero_grad()
|
| 99 |
+
loss_G = self.criterionG(outputs, targets, inputs)
|
| 100 |
+
loss_G.backward()
|
| 101 |
+
self.optimizer_G.step()
|
| 102 |
+
self.metric_counter.add_losses(loss_G.item())
|
| 103 |
+
curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
|
| 104 |
+
self.metric_counter.add_metrics(curr_psnr, curr_ssim)
|
| 105 |
+
tq.set_postfix(loss=self.metric_counter.loss_message())
|
| 106 |
+
if not i:
|
| 107 |
+
self.metric_counter.add_image(img_for_vis, tag='train')
|
| 108 |
+
i += 1
|
| 109 |
+
if i > len(self.train_dataset):
|
| 110 |
+
break
|
| 111 |
+
tq.close()
|
| 112 |
+
self.metric_counter.write_to_tensorboard(epoch)
|
| 113 |
+
|
| 114 |
+
def _validate(self, epoch):
|
| 115 |
+
self.metric_counter.clear()
|
| 116 |
+
epoch_size = config.get('val_batches_per_epoch') or len(self.val_dataset)
|
| 117 |
+
tq = tqdm.tqdm(self.val_dataset)
|
| 118 |
+
tq.set_description('Validation')
|
| 119 |
+
i = 0
|
| 120 |
+
for data in tq:
|
| 121 |
+
with torch.no_grad():
|
| 122 |
+
inputs, targets = self.model.get_input(data)
|
| 123 |
+
outputs, decomp1, decomp2 = self.netG(inputs)
|
| 124 |
+
#outputs = self.netG(inputs)
|
| 125 |
+
loss_G = self.criterionG(outputs, targets, inputs)
|
| 126 |
+
self.metric_counter.add_losses(loss_G.item())
|
| 127 |
+
curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
|
| 128 |
+
self.metric_counter.add_metrics(curr_psnr, curr_ssim)
|
| 129 |
+
if not i:
|
| 130 |
+
self.metric_counter.add_image(img_for_vis, tag='val')
|
| 131 |
+
i += 1
|
| 132 |
+
if i > len(self.train_dataset):
|
| 133 |
+
break
|
| 134 |
+
tq.close()
|
| 135 |
+
self.metric_counter.write_to_tensorboard(epoch, validation=True)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _get_optim(self, params):
|
| 139 |
+
if self.config['optimizer']['name'] == 'adam':
|
| 140 |
+
optimizer = optim.Adam(params, lr=self.config['optimizer']['lr'])
|
| 141 |
+
else:
|
| 142 |
+
raise ValueError("Optimizer [%s] not recognized." % self.config['optimizer']['name'])
|
| 143 |
+
return optimizer
|
| 144 |
+
|
| 145 |
+
def _get_scheduler(self, optimizer):
|
| 146 |
+
if self.config['scheduler']['name'] == 'cosine':
|
| 147 |
+
scheduler = CosineAnnealingLR(optimizer, T_max=self.config['num_epochs'], eta_min=self.config['scheduler']['min_lr'])
|
| 148 |
+
else:
|
| 149 |
+
raise ValueError("Scheduler [%s] not recognized." % self.config['scheduler']['name'])
|
| 150 |
+
return scheduler
|
| 151 |
+
|
| 152 |
+
def _init_params(self):
|
| 153 |
+
self.criterionG = get_loss(self.config['model'])
|
| 154 |
+
self.netG = get_nets(self.config['model'])
|
| 155 |
+
self.netG.cuda()
|
| 156 |
+
self.model = get_model(self.config['model'])
|
| 157 |
+
self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters()))
|
| 158 |
+
self.scheduler_G = self._get_scheduler(self.optimizer_G)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
if __name__ == '__main__':
|
| 162 |
+
with open(os.path.join('config/', args.job_name, 'config_stage2.yaml'), 'r') as f:
|
| 163 |
+
config = yaml.safe_load(f)
|
| 164 |
+
# setup
|
| 165 |
+
torch.backends.cudnn.enabled = True
|
| 166 |
+
torch.backends.cudnn.benchmark = True
|
| 167 |
+
|
| 168 |
+
# set random seed
|
| 169 |
+
seed = 666
|
| 170 |
+
torch.manual_seed(seed)
|
| 171 |
+
torch.cuda.manual_seed(seed)
|
| 172 |
+
random.seed(seed)
|
| 173 |
+
np.random.seed(seed)
|
| 174 |
+
|
| 175 |
+
batch_size = config.pop('batch_size')
|
| 176 |
+
get_dataloader = partial(DataLoader, batch_size=batch_size, num_workers=cpu_count(), shuffle=True, drop_last=False)
|
| 177 |
+
|
| 178 |
+
datasets = map(config.pop, ('train', 'val'))
|
| 179 |
+
datasets = map(PairedDataset.from_config, datasets)
|
| 180 |
+
train, val = map(get_dataloader, datasets)
|
| 181 |
+
trainer = Trainer(config, train=train, val=val)
|
| 182 |
+
trainer.train()
|
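Taken together, the two training scripts form a pipeline: both read config/<job_name>/config_stage{1,2}.yaml and write checkpoints under results/<job_name>/models/, and stage 2 initializes from final_XYScanNet_stage1.pth unless a stage-2 checkpoint already exists. Below is a sketch of the config keys the Trainer actually reads; it is inferred from the script above, not copied from the shipped YAML files, and all values are placeholders.

```python
# Config keys consumed by the Trainer above (inferred sketch; placeholder values).
config_sketch = {
    'experiment_desc': 'XYScanNet_stage2',        # used in checkpoint names (best_/final_/last_*.pth)
    'num_epochs': 1000,
    'batch_size': 8,
    'optimizer': {'name': 'adam', 'lr': 1e-4},
    'scheduler': {'name': 'cosine', 'min_lr': 1e-7},
    'model': {},                                   # passed to get_nets / get_loss / get_model
    'train': {},                                   # PairedDataset.from_config(...) arguments
    'val': {},
    'train_batches_per_epoch': None,               # optional caps read via config.get(...)
    'val_batches_per_epoch': None,
}
```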
util/__init__.py
ADDED
|
File without changes
|
util/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (143 Bytes). View file
|
|
|
util/__pycache__/__init__.cpython-36.pyc
ADDED
|
Binary file (126 Bytes). View file
|
|
|
util/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (145 Bytes). View file
|
|
|