HanzhouLiu committed
Commit b56342d · Parent: 9f55394

Add application file

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. README.md +17 -7
  2. app.py +123 -0
  3. aug.py +56 -0
  4. config/xyscannetp_gopro/config_stage1.yaml +40 -0
  5. config/xyscannetp_gopro/config_stage2.yaml +41 -0
  6. config/xyscannetp_realj/config_stage2.yml +40 -0
  7. config/xyscannetp_realr/config_stage2.yml +40 -0
  8. dataset.py +140 -0
  9. datasets/datasets.txt +2 -0
  10. evaluate_NIQE.m +57 -0
  11. evaluate_RealBlur_J.py +117 -0
  12. evaluate_RealBlur_R.py +110 -0
  13. evaluation_GoPro.m +60 -0
  14. evaluation_HIDE.m +60 -0
  15. examples/blur1.png +3 -0
  16. examples/blur2.png +3 -0
  17. examples/blur3.png +3 -0
  18. examples/blur4.png +3 -0
  19. examples/blur5.png +3 -0
  20. license +37 -0
  21. metric_counter.py +55 -0
  22. models/XYScanNet.py +737 -0
  23. models/XYScanNetP.py +737 -0
  24. models/__init__.py +0 -0
  25. models/__pycache__/XYScanNet.cpython-38.pyc +0 -0
  26. models/__pycache__/XYScanNetP.cpython-38.pyc +0 -0
  27. models/__pycache__/__init__.cpython-38.pyc +0 -0
  28. models/__pycache__/networks.cpython-38.pyc +0 -0
  29. models/losses.py +233 -0
  30. models/models.py +36 -0
  31. models/networks.py +16 -0
  32. models/sota/FFTformer.py +324 -0
  33. models/sota/Restormer.py +340 -0
  34. models/sota/Stripformer.py +429 -0
  35. models/sota/XYScanNet.py +754 -0
  36. out/Results.txt +1 -0
  37. predict_GoPro_test_results.py +89 -0
  38. predict_HIDE_test_results.py +69 -0
  39. predict_RWBI_test_results.py +88 -0
  40. predict_RealBlur_J_test_results.py +97 -0
  41. predict_RealBlur_R_test_results.py +96 -0
  42. requirements.txt +11 -0
  43. results/xyscannetp_gopro/models/best_XYScanNet_stage2.pth +3 -0
  44. schedulers.py +59 -0
  45. train_XYScanNet_stage1.py +182 -0
  46. train_XYScanNet_stage2.py +182 -0
  47. util/__init__.py +0 -0
  48. util/__pycache__/__init__.cpython-310.pyc +0 -0
  49. util/__pycache__/__init__.cpython-36.pyc +0 -0
  50. util/__pycache__/__init__.cpython-38.pyc +0 -0
README.md CHANGED
@@ -1,13 +1,23 @@
1
  ---
2
- title: XYScanNet Demo
3
- emoji: 👁
4
- colorFrom: red
5
- colorTo: purple
6
  sdk: gradio
7
- sdk_version: 5.49.1
8
  app_file: app.py
9
  pinned: false
10
- license: apache-2.0
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: XYScanNet
3
+ emoji: 🚀
4
+ colorFrom: blue
5
+ colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: "4.44.1"
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
+ # XYScanNet: Mamba-based Image Deblurring Demo
13
+
14
+ This Space runs the **XYScanNet** deblurring model on GPU using **Gradio**.
15
+ Upload a blurry image, and the model will restore a sharp version automatically.
16
+
17
+ 🧠 **Tech Highlights**
18
+ - Based on the **Mamba selective state space model**
19
+ - Implements cross-directional strip attention (horizontal & vertical)
20
+ - Runs efficiently on GPU with automatic padding
21
+
22
+ 👤 Author: [Hanzhou Liu](https://huggingface.co/HanzhouLiu)
23
+ 📦 Model weights: [HanzhouLiu/XYScanNet-weights](https://huggingface.co/spaces/HanzhouLiu/XYScanNet_Demo)
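
As a rough usage sketch (not part of the original README), the Space can also be queried programmatically with `gradio_client`; the Space id and the `/predict` endpoint below are assumptions based on the default `gr.Interface` setup in `app.py`:

```python
from gradio_client import Client, handle_file

# Assumed public Space id and the default /predict endpoint of a gr.Interface.
client = Client("HanzhouLiu/XYScanNet_Demo")
result_path = client.predict(
    handle_file("examples/blur1.png"),  # path or URL of a blurry image
    api_name="/predict",
)
print("Deblurred image written to:", result_path)
```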
app.py ADDED
@@ -0,0 +1,123 @@
1
+ import gradio as gr
2
+ import spaces
3
+ import torch
4
+ import torchvision
5
+ import torch.nn.functional as F
6
+ from torch.autograd import Variable
7
+ import numpy as np
8
+ from PIL import Image
9
+ import yaml
10
+ import os
11
+ from models.networks import get_generator
12
+
13
+
14
+ # ===========================
15
+ # 1. Device setup
16
+ # ===========================
17
+ # Automatically choose GPU if available
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+ print(f"🔥 Using device: {device}")
20
+
21
+
22
+ # ===========================
23
+ # 2. Model Loading
24
+ # ===========================
25
+ def load_model(job_name="xyscannetp_gopro"):
26
+ """
27
+ Load the pretrained XYScanNet model on CPU or GPU automatically.
28
+ """
29
+ cfg_path = os.path.join("config", job_name, "config_stage2.yaml")
30
+ with open(cfg_path, "r") as f:
31
+ config = yaml.safe_load(f)
32
+
33
+ weights_path = os.path.join(
34
+ "results", job_name, "models", f"best_{config['experiment_desc']}.pth"
35
+ )
36
+ print(f"🔹 Loading model from {weights_path}")
37
+
38
+ model = get_generator(config["model"])
39
+ model.load_state_dict(torch.load(weights_path, map_location=device))
40
+ model.eval().to(device)
41
+
42
+ print(f"✅ Model loaded on {device}")
43
+ return model
44
+
45
+
46
+ print("Initializing XYScanNet model...")
47
+ MODEL = load_model()
48
+ print("Model ready.")
49
+
50
+
51
+ # ===========================
52
+ # 3. Helper functions
53
+ # ===========================
54
+ def pad_to_multiple_of_8(img_tensor):
55
+ """
56
+ Pad the image tensor so that both height and width are multiples of 8.
57
+ """
58
+ _, _, h, w = img_tensor.shape
59
+ pad_h = (8 - h % 8) % 8
60
+ pad_w = (8 - w % 8) % 8
61
+ img_tensor = F.pad(img_tensor, (0, pad_w, 0, pad_h), mode="reflect")
62
+ return img_tensor, h, w
63
+
64
+
65
+ def crop_back(img_tensor, orig_h, orig_w):
66
+ """Crop output back to original image size."""
67
+ return img_tensor[:, :, :orig_h, :orig_w]
68
+
69
+
70
+ # ===========================
71
+ # 4. Inference Function
72
+ # ===========================
73
+ # The decorator below *requests* GPU if available,
74
+ # but won't crash if only CPU exists.
75
+ @spaces.GPU
76
+ def run_deblur(input_image: Image.Image):
77
+ """
78
+ Run deblurring inference on GPU if available, else CPU.
79
+ """
80
+ # Convert PIL RGB → Tensor [B,C,H,W] normalized to [-0.5,0.5]
81
+ img = np.array(input_image.convert("RGB"))
82
+ img_tensor = (
83
+ torch.from_numpy(np.transpose(img / 255.0, (2, 0, 1)).astype("float32")) - 0.5
84
+ )
85
+ img_tensor = Variable(img_tensor.unsqueeze(0)).to(device)
86
+
87
+ # Pad to valid window size
88
+ img_tensor, orig_h, orig_w = pad_to_multiple_of_8(img_tensor)
89
+
90
+ # Inference
91
+ with torch.no_grad():
92
+ result_image, _, _ = MODEL(img_tensor)
93
+ result_image = result_image + 0.5
94
+ result_image = crop_back(result_image, orig_h, orig_w)
95
+
96
+ # Convert to PIL Image for display
97
+ out_img = result_image.squeeze(0).clamp(0, 1).cpu()
98
+ out_pil = torchvision.transforms.ToPILImage()(out_img)
99
+ return out_pil
100
+
101
+
102
+ # ===========================
103
+ # 5. Gradio Interface
104
+ # ===========================
105
+ demo = gr.Interface(
106
+ fn=run_deblur,
107
+ inputs=gr.Image(type="pil", label="Upload a Blurry Image"),
108
+ outputs=gr.Image(type="pil", label="Deblurred Result"),
109
+ title="XYScanNet: Mamba-based Image Deblurring (GPU Demo)",
110
+ description=(
111
+ "Upload a blurry image to see how XYScanNet restores it using a Mamba-based vision state-space model."
112
+ ),
113
+ examples=[
114
+ ["examples/blur1.jpg"],
115
+ ["examples/blur2.png"],
116
+ ["examples/blur3.jpg"],
117
+ ["examples/blur4.jpg"],
118
+ ["examples/blur5.jpg"],
119
+ ],
120
+ allow_flagging="never",
121
+ )
122
+ if __name__ == "__main__":
123
+ demo.launch()
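
A quick sanity-check sketch for the padding helpers in app.py above (standalone, with the helper restated so it runs on its own; the input size is made up):

```python
import torch
import torch.nn.functional as F

def pad_to_multiple_of_8(t):
    # Same logic as the helper in app.py above.
    _, _, h, w = t.shape
    pad_h, pad_w = (8 - h % 8) % 8, (8 - w % 8) % 8
    return F.pad(t, (0, pad_w, 0, pad_h), mode="reflect"), h, w

x = torch.rand(1, 3, 501, 333)               # neither side is a multiple of 8
padded, h, w = pad_to_multiple_of_8(x)
assert padded.shape[-2:] == (504, 336)       # rounded up to multiples of 8
assert torch.equal(padded[:, :, :h, :w], x)  # cropping back recovers the input
```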
aug.py ADDED
@@ -0,0 +1,56 @@
1
+ from typing import List
2
+
3
+ import albumentations as albu
4
+ from torchvision import transforms
5
+
6
+ def get_transforms(size: int, scope: str = 'geometric', crop='random'):
7
+ augs = {'strong': albu.Compose([albu.HorizontalFlip(),
8
+ albu.ShiftScaleRotate(shift_limit=0.0, scale_limit=0.2, rotate_limit=20, p=.4),
9
+ albu.ElasticTransform(),
10
+ albu.OpticalDistortion(),
11
+ albu.OneOf([
12
+ albu.CLAHE(clip_limit=2),
13
+ albu.Sharpen(),
14
+ albu.Emboss(),
15
+ albu.RandomBrightnessContrast(),
16
+ albu.RandomGamma()
17
+ ], p=0.5),
18
+ albu.OneOf([
19
+ albu.RGBShift(),
20
+ albu.HueSaturationValue(),
21
+ ], p=0.5),
22
+ ]),
23
+ 'weak': albu.Compose([albu.HorizontalFlip(),
24
+ ]),
25
+ 'geometric': albu.Compose([albu.HorizontalFlip(),
26
+ albu.VerticalFlip(),
27
+ albu.RandomRotate90(),
28
+ ]),
29
+ 'None': None
30
+ }
31
+
32
+ aug_fn = augs[scope]
33
+ crop_fn = {'random': albu.RandomCrop(size, size, always_apply=True),
34
+ 'center': albu.CenterCrop(size, size, always_apply=True)}[crop]
35
+
36
+ pipeline = albu.Compose([aug_fn, crop_fn], additional_targets={'target': 'image'})
37
+
38
+
39
+ def process(a, b):
40
+ r = pipeline(image=a, target=b)
41
+ return r['image'], r['target']
42
+
43
+ return process
44
+
45
+
46
+ def get_normalize():
47
+ transform = transforms.Compose([
48
+ transforms.ToTensor()
49
+ ])
50
+
51
+ def process(a, b):
52
+ image = transform(a).permute(1, 2, 0) - 0.5
53
+ target = transform(b).permute(1, 2, 0) - 0.5
54
+ return image, target
55
+
56
+ return process
config/xyscannetp_gopro/config_stage1.yaml ADDED
@@ -0,0 +1,40 @@
1
+ ---
2
+ experiment_desc: XYScanNet_stage1
3
+
4
+ train:
5
+ files_a: /scratch/user/hanzhou1996/datasets/deblur/GOPRO_/train/blur/**/*.png
6
+ files_b: /scratch/user/hanzhou1996/datasets/deblur/GOPRO_/train/sharp/**/*.png
7
+ size: &SIZE 252
8
+ crop: random
9
+ preload: &PRELOAD false
10
+ preload_size: &PRELOAD_SIZE 0
11
+ bounds: [0, 1]
12
+ scope: geometric
13
+
14
+ val:
15
+ files_a: /scratch/user/hanzhou1996/datasets/deblur/GOPRO_/test/blur/**/*.png
16
+ files_b: /scratch/user/hanzhou1996/datasets/deblur/GOPRO_/test/sharp/**/*.png
17
+ size: *SIZE
18
+ scope: None
19
+ crop: random
20
+ preload: *PRELOAD
21
+ preload_size: *PRELOAD_SIZE
22
+ bounds: [0, 1]
23
+
24
+ model:
25
+ g_name: XYScanNetP
26
+ content_loss: Stripformer_Loss
27
+
28
+ num_epochs: 4000
29
+ train_batches_per_epoch: 2103
30
+ val_batches_per_epoch: 1111
31
+ batch_size: 16
32
+ image_size: [252, 252]
33
+
34
+ optimizer:
35
+ name: adam
36
+ lr: 0.00022
37
+ scheduler:
38
+ name: cosine
39
+ start_epoch: 50
40
+ min_lr: 0.0000001
config/xyscannetp_gopro/config_stage2.yaml ADDED
@@ -0,0 +1,41 @@
1
+ ---
2
+ experiment_desc: XYScanNet_stage2
3
+
4
+ train:
5
+ #/mnt/g/RESEARCH/PHD/Motion_Deblurred/datasets/GOPRO_
6
+ files_a: /scratch/user/hanzhou1996/datasets/deblur/GOPRO_/train/blur/**/*.png
7
+ files_b: /scratch/user/hanzhou1996/datasets/deblur/GOPRO_/train/sharp/**/*.png
8
+ size: &SIZE 320
9
+ crop: random
10
+ preload: &PRELOAD false
11
+ preload_size: &PRELOAD_SIZE 0
12
+ bounds: [0, 1]
13
+ scope: geometric
14
+
15
+ val:
16
+ files_a: /scratch/user/hanzhou1996/datasets/deblur/GOPRO_/test/blur/**/*.png
17
+ files_b: /scratch/user/hanzhou1996/datasets/deblur/GOPRO_/test/sharp/**/*.png
18
+ size: *SIZE
19
+ scope: None
20
+ crop: random
21
+ preload: *PRELOAD
22
+ preload_size: *PRELOAD_SIZE
23
+ bounds: [0, 1]
24
+
25
+ model:
26
+ g_name: XYScanNetP
27
+ content_loss: Stripformer_Loss
28
+
29
+ num_epochs: 4000
30
+ train_batches_per_epoch: 2103
31
+ val_batches_per_epoch: 1111
32
+ batch_size: 8
33
+ image_size: [320, 320]
34
+
35
+ optimizer:
36
+ name: adam
37
+ lr: 0.00015
38
+ scheduler:
39
+ name: cosine
40
+ start_epoch: 50
41
+ min_lr: 0.0000001
config/xyscannetp_realj/config_stage2.yml ADDED
@@ -0,0 +1,40 @@
1
+ ---
2
+ experiment_desc: XYScanNet_stage2
3
+
4
+ train:
5
+ files_a: /scratch/user/hanzhou1996/datasets/deblur/RealBlur_J/train/trainA/*.png
6
+ files_b: /scratch/user/hanzhou1996/datasets/deblur/RealBlur_J/train/trainB/*.png
7
+ size: &SIZE 320
8
+ crop: random
9
+ preload: &PRELOAD false
10
+ preload_size: &PRELOAD_SIZE 0
11
+ bounds: [0, 1]
12
+ scope: geometric
13
+
14
+ val:
15
+ files_a: /scratch/user/hanzhou1996/datasets/deblur/RealBlur_J/test/testA/*.png
16
+ files_b: /scratch/user/hanzhou1996/datasets/deblur/RealBlur_J/test/testB/*.png
17
+ size: *SIZE
18
+ scope: None
19
+ crop: random
20
+ preload: *PRELOAD
21
+ preload_size: *PRELOAD_SIZE
22
+ bounds: [0, 1]
23
+
24
+ model:
25
+ g_name: XYScanNetP
26
+ content_loss: Stripformer_Loss
27
+
28
+ num_epochs: 2000
29
+ train_batches_per_epoch: 3758
30
+ val_batches_per_epoch: 980
31
+ batch_size: 8
32
+ image_size: [320, 320]
33
+
34
+ optimizer:
35
+ name: adam
36
+ lr: 0.0001
37
+ scheduler:
38
+ name: cosine
39
+ start_epoch: 50
40
+ min_lr: 0.0000001
config/xyscannetp_realr/config_stage2.yml ADDED
@@ -0,0 +1,40 @@
1
+ ---
2
+ experiment_desc: XYScanNet_stage2
3
+
4
+ train:
5
+ files_a: /scratch/user/hanzhou1996/datasets/deblur/RealBlur_R/train/trainA/*.png
6
+ files_b: /scratch/user/hanzhou1996/datasets/deblur/RealBlur_R/train/trainB/*.png
7
+ size: &SIZE 320
8
+ crop: random
9
+ preload: &PRELOAD false
10
+ preload_size: &PRELOAD_SIZE 0
11
+ bounds: [0, 1]
12
+ scope: geometric
13
+
14
+ val:
15
+ files_a: /scratch/user/hanzhou1996/datasets/deblur/RealBlur_R/test/testA/*.png
16
+ files_b: /scratch/user/hanzhou1996/datasets/deblur/RealBlur_R/test/testB/*.png
17
+ size: *SIZE
18
+ scope: None
19
+ crop: random
20
+ preload: *PRELOAD
21
+ preload_size: *PRELOAD_SIZE
22
+ bounds: [0, 1]
23
+
24
+ model:
25
+ g_name: XYScanNetP
26
+ content_loss: Stripformer_Loss
27
+
28
+ num_epochs: 2000
29
+ train_batches_per_epoch: 3758
30
+ val_batches_per_epoch: 980
31
+ batch_size: 8
32
+ image_size: [320, 320]
33
+
34
+ optimizer:
35
+ name: adam
36
+ lr: 0.0001
37
+ scheduler:
38
+ name: cosine
39
+ start_epoch: 50
40
+ min_lr: 0.0000001
dataset.py ADDED
@@ -0,0 +1,140 @@
1
+ import os
2
+ from copy import deepcopy
3
+ from functools import partial
4
+ from glob import glob
5
+ from hashlib import sha1
6
+ from typing import Callable, Iterable, Optional, Tuple
7
+
8
+ import cv2
9
+ import numpy as np
10
+ from glog import logger
11
+ from joblib import Parallel, cpu_count, delayed
12
+ from skimage.io import imread
13
+ from torch.utils.data import Dataset
14
+ from tqdm import tqdm
15
+
16
+ import aug
17
+
18
+
19
+ def subsample(data: Iterable, bounds: Tuple[float, float], hash_fn: Callable, n_buckets=100, salt='', verbose=True):
20
+ data = list(data)
21
+ buckets = split_into_buckets(data, n_buckets=n_buckets, salt=salt, hash_fn=hash_fn)
22
+
23
+ lower_bound, upper_bound = [x * n_buckets for x in bounds]
24
+ msg = f'Subsampling buckets from {lower_bound} to {upper_bound}, total buckets number is {n_buckets}'
25
+ if salt:
26
+ msg += f'; salt is {salt}'
27
+ if verbose:
28
+ logger.info(msg)
29
+ return np.array([sample for bucket, sample in zip(buckets, data) if lower_bound <= bucket < upper_bound])
30
+
31
+
32
+ def hash_from_paths(x: Tuple[str, str], salt: str = '') -> str:
33
+ path_a, path_b = x
34
+ names = ''.join(map(os.path.basename, (path_a, path_b)))
35
+ return sha1(f'{names}_{salt}'.encode()).hexdigest()
36
+
37
+
38
+ def split_into_buckets(data: Iterable, n_buckets: int, hash_fn: Callable, salt=''):
39
+ hashes = map(partial(hash_fn, salt=salt), data)
40
+ return np.array([int(x, 16) % n_buckets for x in hashes])
41
+
42
+
43
+ def _read_img(x: str):
44
+ img = cv2.imread(x)
45
+ if img is None:
46
+ logger.warning(f'Can not read image {x} with OpenCV, switching to scikit-image')
47
+ img = imread(x)
48
+ return img
49
+
50
+
51
+ class PairedDataset(Dataset):
52
+ def __init__(self,
53
+ files_a: Tuple[str],
54
+ files_b: Tuple[str],
55
+ transform_fn: Callable,
56
+ normalize_fn: Callable,
57
+ corrupt_fn: Optional[Callable] = None,
58
+ preload: bool = True,
59
+ preload_size: Optional[int] = 0,
60
+ verbose=True):
61
+
62
+ assert len(files_a) == len(files_b)
63
+
64
+ self.preload = preload
65
+ self.data_a = files_a
66
+ self.data_b = files_b
67
+ self.verbose = verbose
68
+ self.corrupt_fn = corrupt_fn
69
+ self.transform_fn = transform_fn
70
+ self.normalize_fn = normalize_fn
71
+ logger.info(f'Dataset has been created with {len(self.data_a)} samples')
72
+
73
+ if preload:
74
+ preload_fn = partial(self._bulk_preload, preload_size=preload_size)
75
+ if files_a == files_b:
76
+ self.data_a = self.data_b = preload_fn(self.data_a)
77
+ else:
78
+ self.data_a, self.data_b = map(preload_fn, (self.data_a, self.data_b))
79
+ self.preload = True
80
+
81
+ def _bulk_preload(self, data: Iterable[str], preload_size: int):
82
+ jobs = [delayed(self._preload)(x, preload_size=preload_size) for x in data]
83
+ jobs = tqdm(jobs, desc='preloading images', disable=not self.verbose)
84
+ return Parallel(n_jobs=cpu_count(), backend='threading')(jobs)
85
+
86
+ @staticmethod
87
+ def _preload(x: str, preload_size: int):
88
+ img = _read_img(x)
89
+ if preload_size:
90
+ h, w, *_ = img.shape
91
+ h_scale = preload_size / h
92
+ w_scale = preload_size / w
93
+ scale = max(h_scale, w_scale)
94
+ img = cv2.resize(img, fx=scale, fy=scale, dsize=None)
95
+ assert min(img.shape[:2]) >= preload_size, f'weird img shape: {img.shape}'
96
+ return img
97
+
98
+ def _preprocess(self, img, res):
99
+ def transpose(x):
100
+ return np.transpose(x, (2, 0, 1))
101
+
102
+ return map(transpose, self.normalize_fn(img, res))
103
+
104
+ def __len__(self):
105
+ return len(self.data_a)
106
+
107
+ def __getitem__(self, idx):
108
+ a, b = self.data_a[idx], self.data_b[idx]
109
+ if not self.preload:
110
+ a, b = map(_read_img, (a, b))
111
+ a, b = self.transform_fn(a, b)
112
+ if self.corrupt_fn is not None:
113
+ a = self.corrupt_fn(a)
114
+ a, b = self._preprocess(a, b)
115
+ return {'a': a, 'b': b}
116
+
117
+ @staticmethod
118
+ def from_config(config):
119
+ config = deepcopy(config)
120
+ files_a, files_b = map(lambda x: sorted(glob(config[x], recursive=True)), ('files_a', 'files_b'))
121
+ transform_fn = aug.get_transforms(size=config['size'], scope=config['scope'], crop=config['crop'])
122
+ normalize_fn = aug.get_normalize()
123
+
124
+ hash_fn = hash_from_paths
125
+ # ToDo: add more hash functions
126
+ verbose = config.get('verbose', True)
127
+ data = subsample(data=zip(files_a, files_b),
128
+ bounds=config.get('bounds', (0, 1)),
129
+ hash_fn=hash_fn,
130
+ verbose=verbose)
131
+
132
+ files_a, files_b = map(list, zip(*data))
133
+
134
+ return PairedDataset(files_a=files_a,
135
+ files_b=files_b,
136
+ preload=config['preload'],
137
+ preload_size=config['preload_size'],
138
+ normalize_fn=normalize_fn,
139
+ transform_fn=transform_fn,
140
+ verbose=verbose)
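
A small, self-contained sketch (file names are made up) of the hash-bucket subsampling idea used in PairedDataset.from_config: each blur/sharp pair is assigned a stable bucket in [0, 100), and the `bounds` entry in the config selects a slice of those buckets, so e.g. [0, 0.9] and [0.9, 1] give disjoint, reproducible splits:

```python
from hashlib import sha1

def bucket(pair, n_buckets=100, salt=""):
    # Same idea as hash_from_paths + split_into_buckets above.
    name = "".join(pair) + "_" + salt
    return int(sha1(name.encode()).hexdigest(), 16) % n_buckets

pairs = [(f"blur_{i}.png", f"sharp_{i}.png") for i in range(10)]  # hypothetical files
train   = [p for p in pairs if  0 <= bucket(p) < 90]   # bounds: [0, 0.9]
holdout = [p for p in pairs if 90 <= bucket(p) < 100]  # bounds: [0.9, 1]
assert len(train) + len(holdout) == len(pairs)
```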
datasets/datasets.txt ADDED
@@ -0,0 +1,2 @@
1
+ A good way to manage the datasets in a project is to create a soft link to where the data actually lives.
2
+ In that case, you can simply set the dataset paths in the config files to the local (linked) path, as in the sketch below.
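
For example, a minimal sketch using Python's standard library (the dataset location below is hypothetical):

```python
import os

# Hypothetical absolute location of the downloaded GoPro data.
real_dataset_dir = "/data/GOPRO_"

# Link it into the project so config paths can point at ./datasets/GOPRO_/...
os.makedirs("datasets", exist_ok=True)
link_path = os.path.join("datasets", "GOPRO_")
if not os.path.islink(link_path):
    os.symlink(real_dataset_dir, link_path)
```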
evaluate_NIQE.m ADDED
@@ -0,0 +1,57 @@
1
+ %p = genpath("G:\RESEARCH\PHD\Motion_Deblurred\xyscannet\visualization_bidir\comparison\epoch3k\uni");
2
+ %gt = genpath("G:\RESEARCH\PHD\Motion_Deblurred\xyscannet\visualization_bidir\comparison\epoch3k\gt");
3
+
4
+ length_p = size(p,2);
5
+ path = {};
6
+ temp = [];
7
+ for i = 1:length_p
8
+ if p(i) ~= ';'
9
+ temp = [temp p(i)];
10
+ else
11
+ temp = [temp '\'];
12
+ path = [path ; temp];
13
+ temp = [];
14
+ end
15
+ end
16
+ clear p length_p temp;
17
+ length_gt = size(gt,2);
18
+ path_gt = {};
19
+ temp_gt = [];
20
+ for i = 1:length_gt
21
+ if gt(i) ~= ';'
22
+ temp_gt = [temp_gt gt(i)];
23
+ else
24
+ temp_gt = [temp_gt '\'];
25
+ path_gt = [path_gt ; temp_gt];
26
+ temp_gt = [];
27
+ end
28
+ end
29
+ clear gt length_gt temp_gt;
30
+
31
+ file_num = size(path,1);
32
+ total_niqe = 0;
33
+ n = 0;
34
+ for i = 1:file_num
35
+ file_path = path{i};
36
+ gt_file_path = path_gt{i};
37
+ img_path_list = dir(strcat(file_path,'*.png'));
38
+ gt_path_list = dir(strcat(gt_file_path,'*.png'));
39
+ img_num = length(img_path_list);
40
+ if img_num > 0
41
+ for j = 1:img_num
42
+ image_name = img_path_list(j).name;
43
+ gt_name = gt_path_list(j).name;
44
+ image = imread(strcat(file_path,image_name));
45
+ gt = imread(strcat(gt_file_path,gt_name));
46
+ size(image);
47
+ size(gt);
48
+ cur_niqe = niqe(image);
49
+ fprintf('%d', cur_niqe);
50
+ total_niqe = total_niqe + cur_niqe;
51
+ n = n + 1
52
+ end
53
+ end
54
+ end
55
+ niqe_score = total_niqe / n
56
+ close all;clear all;
57
+
evaluate_RealBlur_J.py ADDED
@@ -0,0 +1,117 @@
1
+ import os
2
+ from skimage import io
3
+ import cv2
4
+ import numpy as np
5
+ from skimage.metrics import structural_similarity
6
+ import concurrent.futures
7
+
8
+ def image_align(deblurred, gt):
9
+ # this function is based on kohler evaluation code
10
+ z = deblurred
11
+ c = np.ones_like(z)
12
+ x = gt
13
+
14
+ zs = (np.sum(x * z) / np.sum(z * z)) * z # simple intensity matching
15
+
16
+ warp_mode = cv2.MOTION_HOMOGRAPHY
17
+ warp_matrix = np.eye(3, 3, dtype=np.float32)
18
+
19
+ # Specify the number of iterations.
20
+ number_of_iterations = 100
21
+
22
+ termination_eps = 0
23
+
24
+ criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
25
+ number_of_iterations, termination_eps)
26
+
27
+ # Run the ECC algorithm. The results are stored in warp_matrix.
28
+ (cc, warp_matrix) = cv2.findTransformECC(cv2.cvtColor(x, cv2.COLOR_RGB2GRAY), cv2.cvtColor(zs, cv2.COLOR_RGB2GRAY),
29
+ warp_matrix, warp_mode, criteria, inputMask=None, gaussFiltSize=5)
30
+
31
+ target_shape = x.shape
32
+ shift = warp_matrix
33
+
34
+ zr = cv2.warpPerspective(
35
+ zs,
36
+ warp_matrix,
37
+ (target_shape[1], target_shape[0]),
38
+ flags=cv2.INTER_CUBIC + cv2.WARP_INVERSE_MAP,
39
+ borderMode=cv2.BORDER_REFLECT)
40
+
41
+ cr = cv2.warpPerspective(
42
+ np.ones_like(zs, dtype='float32'),
43
+ warp_matrix,
44
+ (target_shape[1], target_shape[0]),
45
+ flags=cv2.INTER_NEAREST + cv2.WARP_INVERSE_MAP,
46
+ borderMode=cv2.BORDER_CONSTANT,
47
+ borderValue=0)
48
+
49
+ zr = zr * cr
50
+ xr = x * cr
51
+
52
+ return zr, xr, cr, shift
53
+
54
+ def compute_psnr(image_true, image_test, image_mask, data_range=None):
55
+ # this function is based on skimage.metrics.peak_signal_noise_ratio
56
+ err = np.sum((image_true - image_test) ** 2, dtype=np.float64) / np.sum(image_mask)
57
+ return 10 * np.log10((data_range ** 2) / err)
58
+
59
+
60
+ def compute_ssim(tar_img, prd_img, cr1):
61
+ ssim_pre, ssim_map = structural_similarity(tar_img, prd_img, channel_axis=2, gaussian_weights=True,
62
+ use_sample_covariance=False, data_range=1.0, full=True)
63
+ ssim_map = ssim_map * cr1
64
+ r = int(3.5 * 1.5 + 0.5) # radius as in ndimage
65
+ win_size = 2 * r + 1
66
+ pad = (win_size - 1) // 2
67
+ ssim = ssim_map[pad:-pad, pad:-pad, :]
68
+ crop_cr1 = cr1[pad:-pad, pad:-pad, :]
69
+ ssim = ssim.sum(axis=0).sum(axis=0) / crop_cr1.sum(axis=0).sum(axis=0)
70
+ ssim = np.mean(ssim)
71
+ return ssim
72
+
73
+ total_psnr = 0.
74
+ total_ssim = 0.
75
+ count = 0
76
+ #img_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/xyscannet/ablation/v33/run1/images_realj'
77
+ #img_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/xyscannet/sota/algnet/images/RealBlur_J'
78
+ #img_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/xyscannet/sota/deeprft/FMIMOUNetPLUS_RealBlur/RealBlur_J--'
79
+ #img_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/xyscannet/sota/deeprft/images_author/RealBlur_J'
80
+ #img_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/xyscannet/sota/mprnet/images/RealBlur_J'
81
+ #img_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/xyscannet/sota/stripformer/images/RealBlur_J'
82
+ img_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/xyscannet/sota/xyscannetp/images/realj_final_stage3'
83
+
84
+ gt_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/datasets/Realblur_J/test/testB'
85
+
86
+ print(img_path)
87
+ for file in os.listdir(img_path):
88
+ #for img_name in os.listdir(img_path + '/' + file):
89
+ img_name = file
90
+ count += 1
91
+ number = img_name.split('_')[1]
92
+ #number = img_name.split('-')[1]
93
+ #gt_name = 'gt_' + number
94
+ img_dir = img_path + '/' + file
95
+ s = file.split('_')
96
+ #s = file.split('-')
97
+ #gt_file = s[0] + '_' + 'gt_' + number
98
+ gt_file = '_'.join([s[0], 'gt', s[-1]])
99
+ gt_dir = gt_path + '/' + gt_file
100
+ print(gt_file)
101
+ print(img_dir)
102
+ with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor:
103
+ tar_img = io.imread(gt_dir)
104
+ prd_img = io.imread(img_dir)
105
+ tar_img = tar_img.astype(np.float32) / 255.0
106
+ prd_img = prd_img.astype(np.float32) / 255.0
107
+ prd_img, tar_img, cr1, shift = image_align(prd_img, tar_img)
108
+ PSNR = compute_psnr(tar_img, prd_img, cr1, data_range=1)
109
+ SSIM = compute_ssim(tar_img, prd_img, cr1)
110
+ total_psnr += PSNR
111
+ total_ssim += SSIM
112
+ print(count, PSNR)
113
+
114
+ print('PSNR:', total_psnr / count)
115
+ print('SSIM:', total_ssim / count)
116
+ print(img_path)
117
+
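
For reference, a toy sketch with synthetic arrays (not part of the evaluation script) showing that the masked PSNR above reduces to the usual PSNR when the mask covers every pixel:

```python
import numpy as np

def compute_psnr(image_true, image_test, image_mask, data_range=1.0):
    # Same formula as above: squared error normalized by the number of valid mask entries.
    err = np.sum((image_true - image_test) ** 2, dtype=np.float64) / np.sum(image_mask)
    return 10 * np.log10((data_range ** 2) / err)

rng = np.random.default_rng(0)
gt = rng.random((8, 8, 3)).astype(np.float32)
pred = np.clip(gt + 0.01, 0.0, 1.0)
full_mask = np.ones_like(gt)
print(compute_psnr(gt, pred, full_mask))  # roughly 10*log10(1/0.0001) = 40 dB
```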
evaluate_RealBlur_R.py ADDED
@@ -0,0 +1,110 @@
1
+ import os
2
+ from skimage import io
3
+ import cv2
4
+ import numpy as np
5
+ from skimage.metrics import structural_similarity
6
+ import concurrent.futures
7
+
8
+ def image_align(deblurred, gt):
9
+ # this function is based on kohler evaluation code
10
+ z = deblurred
11
+ c = np.ones_like(z)
12
+ x = gt
13
+
14
+ zs = (np.sum(x * z) / np.sum(z * z)) * z # simple intensity matching
15
+
16
+ warp_mode = cv2.MOTION_HOMOGRAPHY
17
+ warp_matrix = np.eye(3, 3, dtype=np.float32)
18
+
19
+ # Specify the number of iterations.
20
+ number_of_iterations = 100
21
+
22
+ termination_eps = 0
23
+
24
+ criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
25
+ number_of_iterations, termination_eps)
26
+
27
+ # Run the ECC algorithm. The results are stored in warp_matrix.
28
+ (cc, warp_matrix) = cv2.findTransformECC(cv2.cvtColor(x, cv2.COLOR_RGB2GRAY), cv2.cvtColor(zs, cv2.COLOR_RGB2GRAY),
29
+ warp_matrix, warp_mode, criteria, inputMask=None, gaussFiltSize=5)
30
+
31
+ target_shape = x.shape
32
+ shift = warp_matrix
33
+
34
+ zr = cv2.warpPerspective(
35
+ zs,
36
+ warp_matrix,
37
+ (target_shape[1], target_shape[0]),
38
+ flags=cv2.INTER_CUBIC + cv2.WARP_INVERSE_MAP,
39
+ borderMode=cv2.BORDER_REFLECT)
40
+
41
+ cr = cv2.warpPerspective(
42
+ np.ones_like(zs, dtype='float32'),
43
+ warp_matrix,
44
+ (target_shape[1], target_shape[0]),
45
+ flags=cv2.INTER_NEAREST + cv2.WARP_INVERSE_MAP,
46
+ borderMode=cv2.BORDER_CONSTANT,
47
+ borderValue=0)
48
+
49
+ zr = zr * cr
50
+ xr = x * cr
51
+
52
+ return zr, xr, cr, shift
53
+
54
+ def compute_psnr(image_true, image_test, image_mask, data_range=None):
55
+ # this function is based on skimage.metrics.peak_signal_noise_ratio
56
+ err = np.sum((image_true - image_test) ** 2, dtype=np.float64) / np.sum(image_mask)
57
+ return 10 * np.log10((data_range ** 2) / err)
58
+
59
+
60
+ def compute_ssim(tar_img, prd_img, cr1):
61
+ ssim_pre, ssim_map = structural_similarity(tar_img, prd_img, channel_axis=2, gaussian_weights=True,
62
+ use_sample_covariance=False, data_range=1.0, full=True)
63
+ ssim_map = ssim_map * cr1
64
+ r = int(3.5 * 1.5 + 0.5) # radius as in ndimage
65
+ win_size = 2 * r + 1
66
+ pad = (win_size - 1) // 2
67
+ ssim = ssim_map[pad:-pad, pad:-pad, :]
68
+ crop_cr1 = cr1[pad:-pad, pad:-pad, :]
69
+ ssim = ssim.sum(axis=0).sum(axis=0) / crop_cr1.sum(axis=0).sum(axis=0)
70
+ ssim = np.mean(ssim)
71
+ return ssim
72
+
73
+ total_psnr = 0.
74
+ total_ssim = 0.
75
+ count = 0
76
+ #gt_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/datasets/Realblur_R/test/testB'
77
+ img_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/xyscannet/sota/deeprft/FMIMOUNetPLUS_RealBlur/RealBlur_R__'
78
+ #img_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/xyscannet/sota/stripformer/images/RealBlur_R'
79
+ #img_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/xyscannet/sota/deeprft/images/RealBlur_R'
80
+
81
+
82
+ gt_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/datasets/Realblur_R/test/testB'
83
+ print(img_path)
84
+ for file in os.listdir(img_path):
85
+ #for img_name in os.listdir(img_path + '/' + file):
86
+ img_name = file
87
+ count += 1
88
+ number = img_name.split('_')[1]
89
+ #gt_name = 'gt_' + number
90
+ img_dir = img_path + '/' + file
91
+ s = file.split('_')
92
+ gt_file = '_'.join([s[0], 'gt', s[-1]])
93
+ gt_dir = gt_path + '/' + gt_file
94
+ print(img_dir)
95
+ with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor:
96
+ tar_img = io.imread(gt_dir)
97
+ prd_img = io.imread(img_dir)
98
+ tar_img = tar_img.astype(np.float32) / 255.0
99
+ prd_img = prd_img.astype(np.float32) / 255.0
100
+ prd_img, tar_img, cr1, shift = image_align(prd_img, tar_img)
101
+ PSNR = compute_psnr(tar_img, prd_img, cr1, data_range=1)
102
+ SSIM = compute_ssim(tar_img, prd_img, cr1)
103
+ total_psnr += PSNR
104
+ total_ssim += SSIM
105
+ print(count, PSNR)
106
+
107
+ print('PSNR:', total_psnr / count)
108
+ print('SSIM:', total_ssim / count)
109
+ print(img_path)
110
+
evaluation_GoPro.m ADDED
@@ -0,0 +1,60 @@
1
+ p = genpath('.\out\Stripformer_GoPro_results');% GoPro Deblur Results
2
+ gt = genpath('.\datasets\GoPro\test\sharp');% GoPro GT Results
3
+
4
+ length_p = size(p,2);
5
+ path = {};
6
+ temp = [];
7
+ for i = 1:length_p
8
+ if p(i) ~= ';'
9
+ temp = [temp p(i)];
10
+ else
11
+ temp = [temp '\'];
12
+ path = [path ; temp];
13
+ temp = [];
14
+ end
15
+ end
16
+ clear p length_p temp;
17
+ length_gt = size(gt,2);
18
+ path_gt = {};
19
+ temp_gt = [];
20
+ for i = 1:length_gt
21
+ if gt(i) ~= ';'
22
+ temp_gt = [temp_gt gt(i)];
23
+ else
24
+ temp_gt = [temp_gt '\'];
25
+ path_gt = [path_gt ; temp_gt];
26
+ temp_gt = [];
27
+ end
28
+ end
29
+ clear gt length_gt temp_gt;
30
+
31
+ file_num = size(path,1);
32
+ total_psnr = 0;
33
+ n = 0;
34
+ total_ssim = 0;
35
+ for i = 1:file_num
36
+ file_path = path{i};
37
+ gt_file_path = path_gt{i};
38
+ img_path_list = dir(strcat(file_path,'*.png'));
39
+ gt_path_list = dir(strcat(gt_file_path,'*.png'));
40
+ img_num = length(img_path_list);
41
+ if img_num > 0
42
+ for j = 1:img_num
43
+ image_name = img_path_list(j).name;
44
+ gt_name = gt_path_list(j).name;
45
+ image = imread(strcat(file_path,image_name));
46
+ gt = imread(strcat(gt_file_path,gt_name));
47
+ size(image);
48
+ size(gt);
49
+ peaksnr = psnr(image,gt);
50
+ ssimval = ssim(image,gt);
51
+ total_psnr = total_psnr + peaksnr;
52
+ total_ssim = total_ssim + ssimval;
53
+ n = n + 1
54
+ end
55
+ end
56
+ end
57
+ psnr = total_psnr / n
58
+ ssim = total_ssim / n
59
+ close all;clear all;
60
+
evaluation_HIDE.m ADDED
@@ -0,0 +1,60 @@
1
+ p = genpath('.\out\Stripformer_HIDE_results');% HIDE Deblur Results
2
+ gt = genpath('.\datasets\HIDE\sharp');% HIDE GT Results
3
+
4
+ length_p = size(p,2);
5
+ path = {};
6
+ temp = [];
7
+ for i = 1:length_p
8
+ if p(i) ~= ';'
9
+ temp = [temp p(i)];
10
+ else
11
+ temp = [temp '\'];
12
+ path = [path ; temp];
13
+ temp = [];
14
+ end
15
+ end
16
+ clear p length_p temp;
17
+ length_gt = size(gt,2);
18
+ path_gt = {};
19
+ temp_gt = [];
20
+ for i = 1:length_gt
21
+ if gt(i) ~= ';'
22
+ temp_gt = [temp_gt gt(i)];
23
+ else
24
+ temp_gt = [temp_gt '\'];
25
+ path_gt = [path_gt ; temp_gt];
26
+ temp_gt = [];
27
+ end
28
+ end
29
+ clear gt length_gt temp_gt;
30
+
31
+ file_num = size(path,1);
32
+ total_psnr = 0;
33
+ n = 0;
34
+ total_ssim = 0;
35
+ for i = 1:file_num
36
+ file_path = path{i};
37
+ gt_file_path = path_gt{i};
38
+ img_path_list = dir(strcat(file_path,'*.png'));
39
+ gt_path_list = dir(strcat(gt_file_path,'*.png'));
40
+ img_num = length(img_path_list);
41
+ if img_num > 0
42
+ for j = 1:img_num
43
+ image_name = img_path_list(j).name;
44
+ gt_name = gt_path_list(j).name;
45
+ image = imread(strcat(file_path,image_name));
46
+ gt = imread(strcat(gt_file_path,gt_name));
47
+ size(image);
48
+ size(gt);
49
+ peaksnr = psnr(image,gt);
50
+ ssimval = ssim(image,gt);
51
+ total_psnr = total_psnr + peaksnr;
52
+ total_ssim = total_ssim + ssimval;
53
+ n = n + 1
54
+ end
55
+ end
56
+ end
57
+ psnr = total_psnr / n
58
+ ssim = total_ssim / n
59
+ close all;clear all;
60
+
examples/blur1.png ADDED
Git LFS Details

  • SHA256: 7a39f365845b4b6c77882971bba18fae77a28bf410c253a19c026e43e5949207
  • Pointer size: 132 Bytes
  • Size of remote file: 1.04 MB
examples/blur2.png ADDED
Git LFS Details

  • SHA256: e692be16467fbcaf8051a9fe09504b84dc8388b0579d7e5cbbac718ee0ac5f2a
  • Pointer size: 131 Bytes
  • Size of remote file: 961 kB
examples/blur3.png ADDED
Git LFS Details

  • SHA256: 50cf03fc83fbc6ac2e92c1daf0b3580abbc206e136dc032c9a1e248647628a44
  • Pointer size: 131 Bytes
  • Size of remote file: 841 kB
examples/blur4.png ADDED
Git LFS Details

  • SHA256: f9e7dfe63e11c57711881ca9b04ed335a7f41eb34e68e52ee8259dfd0f7c32bf
  • Pointer size: 131 Bytes
  • Size of remote file: 744 kB
examples/blur5.png ADDED
Git LFS Details

  • SHA256: 788fab8dca7aa0440511a15418f4bbb6b559ac2b3c22767b40dfd1bd25621562
  • Pointer size: 131 Bytes
  • Size of remote file: 962 kB
license ADDED
@@ -0,0 +1,37 @@
1
+
2
+ Copyright (c) 2025 Hanzhou Liu
3
+
4
+ Permission is hereby granted, free of charge, to any person obtaining a copy
5
+ of this software and associated documentation files (the "Software"), to deal
6
+ in the Software with non-commercial usage, including non-commercial usage
7
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8
+ copies of the Software, and to permit persons to whom the Software is
9
+ furnished to do so, subject to the following conditions:
10
+
11
+ The above copyright notice and this permission notice shall be included in all
12
+ copies or substantial portions of the Software.
13
+
14
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20
+ SOFTWARE.
21
+
22
+ --------------------------- LICENSE FOR DeblurGANv2 --------------------------------
23
+ BSD License
24
+
25
+ For DeblurGANv2 software
26
+ Copyright (c) 2019, Orest Kupyn, Tetiana Martyniuk, Junru Wu and Zhangyang Wang
27
+ All rights reserved.
28
+
29
+ Redistribution and use in source and binary forms, with or without
30
+ modification, are permitted provided that the following conditions are met:
31
+
32
+ * Redistributions of source code must retain the above copyright notice, this
33
+ list of conditions and the following disclaimer.
34
+
35
+ * Redistributions in binary form must reproduce the above copyright notice,
36
+ this list of conditions and the following disclaimer in the documentation
37
+ and/or other materials provided with the distribution.
metric_counter.py ADDED
@@ -0,0 +1,55 @@
1
+ import logging
2
+ from collections import defaultdict
3
+
4
+ import numpy as np
5
+ from tensorboardX import SummaryWriter
6
+
7
+ WINDOW_SIZE = 100
8
+
9
+
10
+ class MetricCounter:
11
+ def __init__(self, exp_name):
12
+ self.writer = SummaryWriter(exp_name)
13
+ logging.basicConfig(filename='{}.log'.format(exp_name), level=logging.DEBUG)
14
+ self.metrics = defaultdict(list)
15
+ self.images = defaultdict(list)
16
+ self.best_metric = 0
17
+
18
+ def add_image(self, x: np.ndarray, tag: str):
19
+ self.images[tag].append(x)
20
+
21
+ def clear(self):
22
+ self.metrics = defaultdict(list)
23
+ self.images = defaultdict(list)
24
+
25
+ def add_losses(self, l_G):
26
+ for name, value in zip(('G_loss', None), (l_G, None)):
27
+ self.metrics[name].append(value)
28
+
29
+ def add_metrics(self, psnr, ssim):
30
+ for name, value in zip(('PSNR', 'SSIM'),
31
+ (psnr, ssim)):
32
+ self.metrics[name].append(value)
33
+
34
+ def loss_message(self):
35
+ metrics = ((k, np.mean(self.metrics[k][-WINDOW_SIZE:])) for k in ('G_loss', 'PSNR', 'SSIM'))
36
+ return '; '.join(map(lambda x: f'{x[0]}={x[1]:.4f}', metrics))
37
+
38
+ def write_to_tensorboard(self, epoch_num, validation=False):
39
+ scalar_prefix = 'Validation' if validation else 'Train'
40
+ for tag in ('G_loss', 'SSIM', 'PSNR'):
41
+ self.writer.add_scalar(f'{scalar_prefix}_{tag}', np.mean(self.metrics[tag]), global_step=epoch_num)
42
+ for tag in self.images:
43
+ imgs = self.images[tag]
44
+ if imgs:
45
+ imgs = np.array(imgs)
46
+ self.writer.add_images(tag, imgs[:, :, :, ::-1].astype('float32') / 255, dataformats='NHWC',
47
+ global_step=epoch_num)
48
+ self.images[tag] = []
49
+
50
+ def update_best_model(self):
51
+ cur_metric = np.mean(self.metrics['PSNR'])
52
+ if self.best_metric < cur_metric:
53
+ self.best_metric = cur_metric
54
+ return True
55
+ return False
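
A minimal usage sketch (values are made up) of how the training scripts are expected to drive MetricCounter during an epoch; it assumes the class defined above and a working tensorboardX install:

```python
# Hypothetical values; in train_XYScanNet_stage*.py these come from the training loop.
mc = MetricCounter("results/xyscannetp_gopro")
for _ in range(100):
    mc.add_losses(l_G=0.05)             # running generator loss
mc.add_metrics(psnr=31.2, ssim=0.948)   # validation metrics for this epoch
print(mc.loss_message())                # e.g. "G_loss=0.0500; PSNR=31.2000; SSIM=0.9480"
mc.write_to_tensorboard(epoch_num=0, validation=True)
if mc.update_best_model():
    print("New best PSNR; keep this checkpoint.")
mc.clear()
```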
models/XYScanNet.py ADDED
@@ -0,0 +1,737 @@
1
+ import numbers
2
+ import math
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch import Tensor
9
+
10
+ from einops import rearrange, repeat
11
+
12
+ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
13
+
14
+ try:
15
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
16
+ except ImportError:
17
+ causal_conv1d_fn, causal_conv1d_update = None, None
18
+
19
+ try:
20
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
21
+ except ImportError:
22
+ selective_state_update = None
23
+
24
+ try:
25
+ from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
26
+ except ImportError:
27
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
28
+
29
+
30
+ def to_3d(x):
31
+ return rearrange(x, 'b c h w -> b (h w) c')
32
+
33
+
34
+ def to_4d(x, h, w):
35
+ return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
36
+
37
+
38
+ class BiasFree_LayerNorm(nn.Module):
39
+ def __init__(self, normalized_shape):
40
+ super(BiasFree_LayerNorm, self).__init__()
41
+ if isinstance(normalized_shape, numbers.Integral):
42
+ normalized_shape = (normalized_shape,)
43
+ normalized_shape = torch.Size(normalized_shape)
44
+
45
+ assert len(normalized_shape) == 1
46
+
47
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
48
+ self.normalized_shape = normalized_shape
49
+
50
+ def forward(self, x):
51
+ sigma = x.var(-1, keepdim=True, unbiased=False)
52
+ return x / torch.sqrt(sigma + 1e-5) * self.weight
53
+
54
+
55
+ class WithBias_LayerNorm(nn.Module):
56
+ def __init__(self, normalized_shape):
57
+ super(WithBias_LayerNorm, self).__init__()
58
+ if isinstance(normalized_shape, numbers.Integral):
59
+ normalized_shape = (normalized_shape,)
60
+ normalized_shape = torch.Size(normalized_shape)
61
+
62
+ assert len(normalized_shape) == 1
63
+
64
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
65
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
66
+ self.normalized_shape = normalized_shape
67
+
68
+ def forward(self, x):
69
+ mu = x.mean(-1, keepdim=True)
70
+ sigma = x.var(-1, keepdim=True, unbiased=False)
71
+ return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
72
+
73
+
74
+ class LayerNorm(nn.Module):
75
+ def __init__(self, dim, LayerNorm_type):
76
+ super(LayerNorm, self).__init__()
77
+ if LayerNorm_type == 'BiasFree':
78
+ self.body = BiasFree_LayerNorm(dim)
79
+ else:
80
+ self.body = WithBias_LayerNorm(dim)
81
+
82
+ def forward(self, x):
83
+ h, w = x.shape[-2:]
84
+ return to_4d(self.body(to_3d(x)), h, w)
85
+
86
+ ##########################################################################
87
+ def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1):
88
+ return nn.Conv2d(
89
+ in_channels, out_channels, kernel_size,
90
+ padding=(kernel_size//2), bias=bias, stride = stride)
91
+
92
+
93
+ """
94
+ Borrow from "https://github.com/state-spaces/mamba.git"
95
+ @article{mamba,
96
+ title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
97
+ author={Gu, Albert and Dao, Tri},
98
+ journal={arXiv preprint arXiv:2312.00752},
99
+ year={2023}
100
+ }
101
+ """
102
+ class Mamba(nn.Module):
103
+ def __init__(
104
+ self,
105
+ d_model,
106
+ d_state=16,
107
+ d_conv=4,
108
+ expand=2,
109
+ dt_rank="auto",
110
+ dt_min=0.001,
111
+ dt_max=0.1,
112
+ dt_init="random",
113
+ dt_scale=1.0,
114
+ dt_init_floor=1e-4,
115
+ conv_bias=True,
116
+ bias=False,
117
+ use_fast_path=True, # Fused kernel options
118
+ layer_idx=None,
119
+ device=None,
120
+ dtype=None,
121
+ ):
122
+ factory_kwargs = {"device": device, "dtype": dtype}
123
+ super().__init__()
124
+ self.d_model = d_model
125
+ self.d_state = d_state
126
+ self.d_conv = d_conv
127
+ self.expand = expand
128
+ self.d_inner = int(self.expand * self.d_model)
129
+ self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
130
+ self.use_fast_path = use_fast_path
131
+ self.layer_idx = layer_idx
132
+
133
+ self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
134
+
135
+ self.conv1d = nn.Conv1d(
136
+ in_channels=self.d_inner,
137
+ out_channels=self.d_inner,
138
+ bias=conv_bias,
139
+ kernel_size=d_conv,
140
+ groups=self.d_inner,
141
+ padding=d_conv - 1,
142
+ **factory_kwargs,
143
+ )
144
+
145
+ self.activation = "silu"
146
+ self.act = nn.SiLU()
147
+
148
+ self.x_proj = nn.Linear(
149
+ self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
150
+ )
151
+ self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
152
+
153
+ # Initialize special dt projection to preserve variance at initialization
154
+ dt_init_std = self.dt_rank**-0.5 * dt_scale
155
+ if dt_init == "constant":
156
+ nn.init.constant_(self.dt_proj.weight, dt_init_std)
157
+ elif dt_init == "random":
158
+ nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
159
+ else:
160
+ raise NotImplementedError
161
+
162
+ # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
163
+ dt = torch.exp(
164
+ torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
165
+ + math.log(dt_min)
166
+ ).clamp(min=dt_init_floor)
167
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
168
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
169
+ with torch.no_grad():
170
+ self.dt_proj.bias.copy_(inv_dt)
171
+ # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
172
+ self.dt_proj.bias._no_reinit = True
173
+
174
+ # S4D real initialization
175
+ A = repeat(
176
+ torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
177
+ "n -> d n",
178
+ d=self.d_inner,
179
+ ).contiguous()
180
+ A_log = torch.log(A) # Keep A_log in fp32
181
+ self.A_log = nn.Parameter(A_log)
182
+ self.A_log._no_weight_decay = True
183
+
184
+ # D "skip" parameter
185
+ self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
186
+ self.D._no_weight_decay = True
187
+
188
+ self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
189
+
190
+ def forward(self, hidden_states, inference_params=None):
191
+ """
192
+ hidden_states: (B, L, D)
193
+ Returns: same shape as hidden_states
194
+ """
195
+ batch, seqlen, dim = hidden_states.shape
196
+
197
+ conv_state, ssm_state = None, None
198
+ if inference_params is not None:
199
+ conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
200
+ if inference_params.seqlen_offset > 0:
201
+ # The states are updated inplace
202
+ out, _, _ = self.step(hidden_states, conv_state, ssm_state)
203
+ return out
204
+
205
+ # We do matmul and transpose BLH -> HBL at the same time
206
+ xz = rearrange(
207
+ self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
208
+ "d (b l) -> b d l",
209
+ l=seqlen,
210
+ )
211
+ if self.in_proj.bias is not None:
212
+ xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
213
+
214
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
215
+ # In the backward pass we write dx and dz next to each other to avoid torch.cat
216
+ if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None: # Doesn't support outputting the states
217
+ out = mamba_inner_fn(
218
+ xz,
219
+ self.conv1d.weight,
220
+ self.conv1d.bias,
221
+ self.x_proj.weight,
222
+ self.dt_proj.weight,
223
+ self.out_proj.weight,
224
+ self.out_proj.bias,
225
+ A,
226
+ None, # input-dependent B
227
+ None, # input-dependent C
228
+ self.D.float(),
229
+ delta_bias=self.dt_proj.bias.float(),
230
+ delta_softplus=True,
231
+ )
232
+ else:
233
+ x, z = xz.chunk(2, dim=1)
234
+ # Compute short convolution
235
+ if conv_state is not None:
236
+ # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
237
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
238
+ conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
239
+ if causal_conv1d_fn is None:
240
+ x = self.act(self.conv1d(x)[..., :seqlen])
241
+ else:
242
+ assert self.activation in ["silu", "swish"]
243
+ x = causal_conv1d_fn(
244
+ x=x,
245
+ weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
246
+ bias=self.conv1d.bias,
247
+ activation=self.activation,
248
+ )
249
+
250
+ # We're careful here about the layout, to avoid extra transposes.
251
+ # We want dt to have d as the slowest moving dimension
252
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
253
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
254
+ dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
255
+ dt = self.dt_proj.weight @ dt.t()
256
+ dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
257
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
258
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
259
+ assert self.activation in ["silu", "swish"]
260
+ y = selective_scan_fn(
261
+ x,
262
+ dt,
263
+ A,
264
+ B,
265
+ C,
266
+ self.D.float(),
267
+ z=z,
268
+ delta_bias=self.dt_proj.bias.float(),
269
+ delta_softplus=True,
270
+ return_last_state=ssm_state is not None,
271
+ )
272
+ if ssm_state is not None:
273
+ y, last_state = y
274
+ ssm_state.copy_(last_state)
275
+ y = rearrange(y, "b d l -> b l d")
276
+ out = self.out_proj(y)
277
+ return out
278
+
279
+ def step(self, hidden_states, conv_state, ssm_state):
280
+ dtype = hidden_states.dtype
281
+ assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
282
+ xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
283
+ x, z = xz.chunk(2, dim=-1) # (B D)
284
+
285
+ # Conv step
286
+ if causal_conv1d_update is None:
287
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
288
+ conv_state[:, :, -1] = x
289
+ x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
290
+ if self.conv1d.bias is not None:
291
+ x = x + self.conv1d.bias
292
+ x = self.act(x).to(dtype=dtype)
293
+ else:
294
+ x = causal_conv1d_update(
295
+ x,
296
+ conv_state,
297
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
298
+ self.conv1d.bias,
299
+ self.activation,
300
+ )
301
+
302
+ x_db = self.x_proj(x) # (B dt_rank+2*d_state)
303
+ dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
304
+ # Don't add dt_bias here
305
+ dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
306
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
307
+
308
+ # SSM step
309
+ if selective_state_update is None:
310
+ # Discretize A and B
311
+ dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
312
+ dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
313
+ dB = torch.einsum("bd,bn->bdn", dt, B)
314
+ ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
315
+ y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
316
+ y = y + self.D.to(dtype) * x
317
+ y = y * self.act(z) # (B D)
318
+ else:
319
+ y = selective_state_update(
320
+ ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
321
+ )
322
+
323
+ out = self.out_proj(y)
324
+ return out.unsqueeze(1), conv_state, ssm_state
325
+
326
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
327
+ device = self.out_proj.weight.device
328
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
329
+ conv_state = torch.zeros(
330
+ batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
331
+ )
332
+ ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
333
+ # ssm_dtype = torch.float32
334
+ ssm_state = torch.zeros(
335
+ batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
336
+ )
337
+ return conv_state, ssm_state
338
+
339
+ def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
340
+ assert self.layer_idx is not None
341
+ if self.layer_idx not in inference_params.key_value_memory_dict:
342
+ batch_shape = (batch_size,)
343
+ conv_state = torch.zeros(
344
+ batch_size,
345
+ self.d_model * self.expand,
346
+ self.d_conv,
347
+ device=self.conv1d.weight.device,
348
+ dtype=self.conv1d.weight.dtype,
349
+ )
350
+ ssm_state = torch.zeros(
351
+ batch_size,
352
+ self.d_model * self.expand,
353
+ self.d_state,
354
+ device=self.dt_proj.weight.device,
355
+ dtype=self.dt_proj.weight.dtype,
356
+ # dtype=torch.float32,
357
+ )
358
+ inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
359
+ else:
360
+ conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
361
+ # TODO: What if batch size changes between generation, and we reuse the same states?
362
+ if initialize_states:
363
+ conv_state.zero_()
364
+ ssm_state.zero_()
365
+ return conv_state, ssm_state
366
+
367
+ ##########################################################################
368
+ ## Feed-forward Network
369
+ class FFN(nn.Module):
370
+ def __init__(self, dim, ffn_expansion_factor, bias):
371
+ super(FFN, self).__init__()
372
+
373
+ hidden_features = int(dim*ffn_expansion_factor)
374
+
375
+ self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
376
+
377
+ self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias, dilation=1)
378
+
379
+ self.win_size = 8
380
+
381
+ self.modulator = nn.Parameter(torch.ones(self.win_size, self.win_size, dim*2)) # modulator
382
+
383
+ self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
384
+
385
+ def forward(self, x):
386
+ b, c, h, w = x.shape
387
+ h1, w1 = h//self.win_size, w//self.win_size
388
+ x = self.project_in(x)
389
+ x = self.dwconv(x)
390
+ x_win = rearrange(x, 'b c (wsh h1) (wsw w1) -> b h1 w1 wsh wsw c', wsh=self.win_size, wsw=self.win_size)
391
+ x_win = x_win * self.modulator
392
+ x = rearrange(x_win, 'b h1 w1 wsh wsw c -> b c (wsh h1) (wsw w1)', wsh=self.win_size, wsw=self.win_size, h1=h1, w1=w1)
393
+ x1, x2 = x.chunk(2, dim=1)
394
+ x = x1 * x2
395
+ x = self.project_out(x)
396
+ return x
397
+
398
+
399
+ ##########################################################################
400
+ ## Gated Depth-wise Feed-forward Network (GDFN)
401
+ class GDFN(nn.Module):
402
+ def __init__(self, dim, ffn_expansion_factor, bias):
403
+ super(GDFN, self).__init__()
404
+
405
+ hidden_features = int(dim*ffn_expansion_factor)
406
+
407
+ self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
408
+
409
+ self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias, dilation=1)
410
+
411
+ self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
412
+
413
+ def forward(self, x):
414
+ x = self.project_in(x)
415
+ x = self.dwconv(x)
416
+ x1, x2 = x.chunk(2, dim=1)
417
+ x = F.silu(x1) * x2
418
+ x = self.project_out(x)
419
+ return x
420
+
421
+
422
+ ##########################################################################
423
+ ## Overlapped image patch embedding with 3x3 Conv
424
+ class OverlapPatchEmbed(nn.Module):
425
+ def __init__(self, in_c=3, embed_dim=48, bias=False):
426
+ super(OverlapPatchEmbed, self).__init__()
427
+
428
+ self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)
429
+
430
+ def forward(self, x):
431
+ x = self.proj(x)
432
+
433
+ return x
434
+
435
+
436
+ ##########################################################################
437
+ ## Resizing modules
438
+ class Downsample(nn.Module):
439
+ def __init__(self, n_feat):
440
+ super(Downsample, self).__init__()
441
+
442
+ self.body = nn.Sequential(nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
443
+ nn.Conv2d(n_feat, n_feat * 2, 3, stride=1, padding=1, bias=False))
444
+
445
+ def forward(self, x):
446
+ return self.body(x)
447
+
448
+
449
+ class Upsample(nn.Module):
450
+ def __init__(self, n_feat):
451
+ super(Upsample, self).__init__()
452
+
453
+ self.body = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
454
+ nn.Conv2d(n_feat, n_feat // 2, 3, stride=1, padding=1, bias=False))
455
+
456
+ def forward(self, x):
457
+ return self.body(x)
458
+
459
+
460
+ """
461
+ Borrow from "https://github.com/pp00704831/Stripformer-ECCV-2022-.git"
462
+ @inproceedings{Tsai2022Stripformer,
463
+ author = {Fu-Jen Tsai and Yan-Tsung Peng and Yen-Yu Lin and Chung-Chi Tsai and Chia-Wen Lin},
464
+ title = {Stripformer: Strip Transformer for Fast Image Deblurring},
465
+ booktitle = {ECCV},
466
+ year = {2022}
467
+ }
468
+ """
469
+ class Intra_VSSM(nn.Module):
470
+ def __init__(self, dim, vssm_expansion_factor, bias): # gated = True
471
+ super(Intra_VSSM, self).__init__()
472
+ hidden = int(dim*vssm_expansion_factor)
473
+
474
+ self.proj_in = nn.Conv2d(dim, hidden*2, kernel_size=1, bias=bias)
475
+ self.dwconv = nn.Conv2d(hidden*2, hidden*2, kernel_size=3, stride=1, padding=1, groups=hidden*2, bias=bias)
476
+ self.proj_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)
477
+
478
+ self.conv_input = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
479
+ self.fuse_out = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
480
+ self.mamba = Mamba(d_model=hidden // 2)
481
+
482
+ def forward_core(self, x):
483
+ B, C, H, W = x.size()
484
+
485
+ x_input = torch.chunk(self.conv_input(x), 2, dim=1)
486
+
487
+ feature_h = (x_input[0]).permute(0, 2, 3, 1).contiguous()
488
+ feature_h = feature_h.view(B * H, W, C//2)
489
+
490
+ feature_v = (x_input[1]).permute(0, 3, 2, 1).contiguous()
491
+ feature_v = feature_v.view(B * W, H, C//2)
492
+
493
+ if H == W:
494
+ feature = torch.cat((feature_h, feature_v), dim=0) # B * H * 2, W, C//2
495
+ scan_output = self.mamba(feature)
496
+ scan_output = torch.chunk(scan_output, 2, dim=0)
497
+ scan_output_h = scan_output[0]
498
+ scan_output_v = scan_output[1]
499
+ else:
500
+ scan_output_h = self.mamba(feature_h)
501
+ scan_output_v = self.mamba(feature_v)
502
+
503
+ scan_output_h = scan_output_h.view(B, H, W, C//2).permute(0, 3, 1, 2).contiguous()
504
+ scan_output_v = scan_output_v.view(B, W, H, C//2).permute(0, 3, 2, 1).contiguous()
505
+ scan_output = self.fuse_out(torch.cat((scan_output_h, scan_output_v), dim=1))
506
+
507
+ return scan_output
508
+
509
+ def forward(self, x):
510
+ x = self.proj_in(x)
511
+ x, x_ = self.dwconv(x).chunk(2, dim=1)
512
+ x = self.forward_core(x)
513
+ x = F.silu(x_) * x
514
+ x = self.proj_out(x)
515
+ return x
516
+
517
+
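How the intra-strip scan above linearizes a feature map (a standalone sketch of the reshapes only): one half of the channels is read as H row sequences of length W, the other half as W column sequences of length H, and both batches of sequences are fed to a single shared Mamba (in one concatenated call when H == W, in two calls otherwise).

import torch

B, C, H, W = 1, 8, 6, 4
x_h = torch.randn(B, C // 2, H, W)                              # horizontal half of the channels
x_v = torch.randn(B, C // 2, H, W)                              # vertical half of the channels
feature_h = x_h.permute(0, 2, 3, 1).reshape(B * H, W, C // 2)   # one sequence per row
feature_v = x_v.permute(0, 3, 2, 1).reshape(B * W, H, C // 2)   # one sequence per column
print(feature_h.shape, feature_v.shape)                         # (6, 4, 4) and (4, 6, 4)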
518
+ class Inter_VSSM(nn.Module):
519
+ def __init__(self, dim, vssm_expansion_factor, bias): # gated = True
520
+ super(Inter_VSSM, self).__init__()
521
+ hidden = int(dim*vssm_expansion_factor)
522
+
523
+ self.proj_in = nn.Conv2d(dim, hidden*2, kernel_size=1, bias=bias)
524
+ self.dwconv = nn.Conv2d(hidden*2, hidden*2, kernel_size=3, stride=1, padding=1, groups=hidden*2, bias=bias)
525
+ self.proj_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)
526
+
527
+ self.avg_pool = nn.AdaptiveAvgPool2d((None,1))
528
+ self.conv_input = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
529
+ self.fuse_out = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
530
+ self.mamba = Mamba(d_model=hidden // 2)
531
+ self.sigmoid = nn.Sigmoid()
532
+
533
+ def forward_core(self, x):
534
+ B, C, H, W = x.size()
535
+
536
+ x_input = torch.chunk(self.conv_input(x), 2, dim=1) # B, C, H, W
537
+
538
+ feature_h = x_input[0].permute(0, 2, 1, 3).contiguous() # B, H, C//2, W
539
+ feature_h_score = self.avg_pool(feature_h) # B, H, C//2, 1
540
+ feature_h_score = feature_h_score.view(B, H, -1)
541
+
542
+ feature_v = x_input[1].permute(0, 3, 1, 2).contiguous() # B, W, C//2, H
543
+ feature_v_score = self.avg_pool(feature_v) # B, W, C//2, 1
544
+ feature_v_score = feature_v_score.view(B, W, -1)
545
+
546
+ if H == W:
547
+ feature_score = torch.cat((feature_h_score, feature_v_score), dim=0) # B * 2, W or H, C//2
548
+ scan_score = self.mamba(feature_score)
549
+ scan_score = torch.chunk(scan_score, 2, dim=0)
550
+ scan_score_h = scan_score[0]
551
+ scan_score_v = scan_score[1]
552
+ else:
553
+ scan_score_h = self.mamba(feature_h_score)
554
+ scan_score_v = self.mamba(feature_v_score)
555
+
556
+ scan_score_h = self.sigmoid(scan_score_h)
557
+ scan_score_v = self.sigmoid(scan_score_v)
558
+ feature_h = feature_h*scan_score_h[:,:,:,None]
559
+ feature_v = feature_v*scan_score_v[:,:,:,None]
560
+ feature_h = feature_h.view(B, H, C//2, W).permute(0, 2, 1, 3).contiguous()
561
+ feature_v = feature_v.view(B, W, C//2, H).permute(0, 2, 3, 1).contiguous()
562
+ output = self.fuse_out(torch.cat((feature_h, feature_v), dim=1))
563
+
564
+ return output
565
+
566
+ def forward(self, x):
567
+ x = self.proj_in(x)
568
+ x, x_ = self.dwconv(x).chunk(2, dim=1)
569
+ x = self.forward_core(x)
570
+ x = F.silu(x_) * x
571
+ x = self.proj_out(x)
572
+ return x
573
+
574
+
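The inter-strip scan above works on pooled strip descriptors rather than full strips: every row (and column) is average-pooled to a C//2-dim vector, the descriptor sequence is scanned by Mamba, and a sigmoid of the result rescales the whole strip. A shape-only sketch of the horizontal branch:

import torch

B, C, H, W = 1, 8, 6, 4
feature_h = torch.randn(B, H, C // 2, W)        # rows, channel-major per row
feature_h_score = feature_h.mean(dim=-1)        # (B, H, C//2): one descriptor per row
gate = torch.sigmoid(feature_h_score)           # stands in for sigmoid(self.mamba(score))
reweighted = feature_h * gate[:, :, :, None]    # each row rescaled channel-wise
assert reweighted.shape == feature_h.shape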
575
+ ##########################################################################
576
+ class Strip_VSSB(nn.Module):
577
+ def __init__(self, dim, vssm_expansion_factor, ffn_expansion_factor, bias=False, ssm=False, LayerNorm_type='WithBias'):
578
+ super(Strip_VSSB, self).__init__()
579
+ self.ssm = ssm
580
+ if self.ssm == True:
581
+ self.norm1_ssm = LayerNorm(dim, LayerNorm_type)
582
+ self.norm2_ssm = LayerNorm(dim, LayerNorm_type)
583
+ self.intra = Intra_VSSM(dim, vssm_expansion_factor, bias)
584
+ self.inter = Inter_VSSM(dim, vssm_expansion_factor, bias)
585
+ self.norm1_ffn = LayerNorm(dim, LayerNorm_type)
586
+ self.norm2_ffn = LayerNorm(dim, LayerNorm_type)
587
+ self.ffn1 = GDFN(dim, ffn_expansion_factor, bias)
588
+ self.ffn2 = GDFN(dim, ffn_expansion_factor, bias)
589
+
590
+ def forward(self, x):
591
+ if self.ssm == True:
592
+ x = x + self.intra(self.norm1_ssm(x))
593
+ x = x + self.ffn1(self.norm1_ffn(x))
594
+ if self.ssm == True:
595
+ x = x + self.inter(self.norm2_ssm(x))
596
+ x = x + self.ffn2(self.norm2_ffn(x))
597
+
598
+ return x
599
+
600
+
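With ssm=True the block above is two pre-norm residual pairs, (Intra_VSSM then GDFN) followed by (Inter_VSSM then GDFN); with ssm=False, the setting used for the encoder levels, only the two GDFN pairs remain. A construction-only check of that switch (assumes mamba_ssm is installed):

enc_block = Strip_VSSB(dim=72, vssm_expansion_factor=1, ffn_expansion_factor=1, ssm=False)
dec_block = Strip_VSSB(dim=72, vssm_expansion_factor=1, ffn_expansion_factor=1, ssm=True)
print(hasattr(enc_block, 'intra'), hasattr(dec_block, 'intra'))   # False True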
601
+ ##########################################################################
602
+ ##---------- Cross-level Feature Fusion by Adding Sigmoid(KL-Div) * Multi-Scale Feat -----------------------
603
+ class CLFF(nn.Module):
604
+ def __init__(self, dim, dim_n1, dim_n2, bias=False):
605
+ super(CLFF, self).__init__()
606
+
607
+ self.conv = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
608
+ self.conv_n1 = nn.Conv2d(dim_n1, dim, kernel_size=1, bias=bias)
609
+ self.conv_n2 = nn.Conv2d(dim_n2, dim, kernel_size=1, bias=bias)
610
+ self.fuse_out1 = nn.Conv2d(dim*2, dim, kernel_size=1, bias=bias)
611
+
612
+ self.log_sigmoid = nn.LogSigmoid()
613
+ self.sigmoid = nn.Sigmoid()
614
+
615
+ def forward(self, x, n1, n2):
616
+ x_ = self.conv(x)
617
+ n1_ = self.conv_n1(n1)
618
+ n2_ = self.conv_n2(n2)
619
+ kl_n1 = F.kl_div(input=self.log_sigmoid(n1_), target=self.log_sigmoid(x_), log_target=True)
620
+ kl_n2 = F.kl_div(input=self.log_sigmoid(n2_), target=self.log_sigmoid(x_), log_target=True)
621
+ #g = self.sigmoid(x_)
622
+ g1 = self.sigmoid(kl_n1)
623
+ g2 = self.sigmoid(kl_n2)
624
+ #x = (1 + g) * x_ + (1 - g) * (g1 * n1_ + g2 * n2_)
625
+ x = self.fuse_out1(torch.cat((x_, g1 * n1_ + g2 * n2_), dim=1))
626
+
627
+ return x
628
+
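One detail of the fusion above worth noting: F.kl_div with its default 'mean' reduction returns a single scalar, so g1 and g2 act as global gates on the neighbouring-level features rather than per-pixel maps. Minimal illustration:

import torch
import torch.nn.functional as F

x_ = torch.randn(1, 72, 16, 16)
n1_ = torch.randn(1, 72, 16, 16)
kl = F.kl_div(input=F.logsigmoid(n1_), target=F.logsigmoid(x_), log_target=True)
print(kl.shape, torch.sigmoid(kl))   # torch.Size([]) -> a single scalar gate in (0, 1)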
629
+ ##########################################################################
630
+ ##---------- XYScanNet -----------------------
631
+ class XYScanNet(nn.Module):
632
+ def __init__(self,
633
+ inp_channels=3,
634
+ out_channels=3,
635
+ dim = 72, # 48, 72, 96, 120, 144
636
+ num_blocks = [3,3,6],
637
+ vssm_expansion_factor = 1, # 1 or 2
638
+ ffn_expansion_factor = 1, # 1 or 3
639
+ bias = False,
640
+ LayerNorm_type = 'WithBias', ## Other option 'BiasFree'
641
+ ):
642
+
643
+ super(XYScanNet, self).__init__()
644
+
645
+ self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
646
+
647
+ self.encoder_level1 = nn.Sequential(*[Strip_VSSB(dim=dim, vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
648
+ bias=bias, ssm=False, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
649
+
650
+ self.down1_2 = Downsample(dim) ## From Level 1 to Level 2
651
+ self.encoder_level2 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**1), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
652
+ bias=bias, ssm=False, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
653
+
654
+ self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3
655
+ self.encoder_level3 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**2), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
656
+ bias=bias, ssm=False, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
657
+
658
+ self.decoder_level3 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**2), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
659
+ bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
660
+
661
+ self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2
662
+ self.clff_level2 = CLFF(int(dim*2**1), dim_n1=int(dim*2**0), dim_n2=(dim*2**2), bias=bias)
663
+ self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
664
+ self.decoder_level2 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**1), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
665
+ bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
666
+
667
+ self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1
668
+ self.clff_level1 = CLFF(int(dim*2**0), dim_n1=int(dim*2**1), dim_n2=(dim*2**2), bias=bias)
669
+ self.reduce_chan_level1 = nn.Conv2d(int(dim*2**1), int(dim*2**0), kernel_size=1, bias=bias)
670
+ self.decoder_level1 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**0), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
671
+ bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
672
+
673
+ # self.refinement = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**0), expansion_factor=expansion_factor, bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])
674
+
675
+ self.output = nn.Conv2d(int(dim*2**0), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
676
+
677
+ def forward(self, inp_img):
678
+
679
+ # Encoder
680
+ inp_enc_level1 = self.patch_embed(inp_img)
681
+ out_enc_level1 = self.encoder_level1(inp_enc_level1)
682
+ out_enc_level1_2 = F.interpolate(out_enc_level1, scale_factor=0.5) # dim*2, lvl1 down-scaled to lvl2
683
+
684
+ inp_enc_level2 = self.down1_2(out_enc_level1)
685
+ out_enc_level2 = self.encoder_level2(inp_enc_level2)
686
+ out_enc_level2_1 = F.interpolate(out_enc_level2, scale_factor=2) # dim*2, lvl2 up-scaled to lvl1
687
+
688
+ inp_enc_level3 = self.down2_3(out_enc_level2)
689
+ out_enc_level3 = self.encoder_level3(inp_enc_level3)
690
+ out_enc_level3_2 = F.interpolate(out_enc_level3, scale_factor=2) # dim*2**2, lvl3 up-scaled to lvl2 (lvl3->lvl2)
691
+ out_enc_level3_1 = F.interpolate(out_enc_level3_2, scale_factor=2) # dim*2**2, lvl3 up-scaled to lvl1 (lvl3->lvl2->lvl1)
692
+
693
+ out_enc_level1 = self.clff_level1(out_enc_level1, out_enc_level2_1, out_enc_level3_1)
694
+ out_enc_level2 = self.clff_level2(out_enc_level2, out_enc_level1_2, out_enc_level3_2)
695
+
696
+ # Decoder
697
+ out_dec_level3_decomp1 = self.decoder_level3(out_enc_level3)
698
+
699
+ inp_dec_level2_decomp1 = self.up3_2(out_dec_level3_decomp1)
700
+ inp_dec_level2_decomp1 = self.reduce_chan_level2(torch.cat((inp_dec_level2_decomp1, out_enc_level2), dim=1))
701
+ out_dec_level2_decomp1 = self.decoder_level2(inp_dec_level2_decomp1)
702
+
703
+ inp_dec_level1_decomp1 = self.up2_1(out_dec_level2_decomp1)
704
+ inp_dec_level1_decomp1 = self.reduce_chan_level1(torch.cat((inp_dec_level1_decomp1, out_enc_level1), dim=1))
705
+ out_dec_level1_decomp1 = self.decoder_level1(inp_dec_level1_decomp1)
706
+
707
+ out_dec_level1_decomp1 = self.output(out_dec_level1_decomp1)
708
+
709
+ out_dec_level1 = out_dec_level1_decomp1 + inp_img
710
+
711
+
712
+ return out_dec_level1, out_dec_level1_decomp1, None
713
+
714
+ def count_parameters(model):
715
+ total = sum(p.numel() for p in model.parameters())
716
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
717
+ print(f"Total parameters: {total:,}")
718
+ print(f"Trainable parameters: {trainable:,}")
719
+
720
+
721
+ def main():
722
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
723
+ model = XYScanNet().to(device)
724
+
725
+ print("Model architecture:\n")
726
+ print(model)
727
+
728
+ count_parameters(model)
729
+
730
+ # Optionally test with a dummy input
731
+ dummy_input = torch.randn(1, 3, 256, 256).to(device)
732
+ output, _, _ = model(dummy_input)
733
+ print(f"Output shape: {output.shape}")
734
+
735
+
736
+ if __name__ == "__main__":
737
+ main()
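Inference-time note (not part of the original file): the two downsampling stages plus the bilinear re-scaling in the skip paths assume spatial sizes that divide cleanly, so arbitrary inputs can trigger shape mismatches at the torch.cat calls. A common workaround, sketched here under that assumption, is to reflection-pad the input to a multiple of 8 and crop the restored output back to the original size:

import torch.nn.functional as F

def deblur_padded(model, img, multiple=8):
    _, _, h, w = img.shape
    ph = (multiple - h % multiple) % multiple
    pw = (multiple - w % multiple) % multiple
    padded = F.pad(img, (0, pw, 0, ph), mode='reflect')
    restored, _, _ = model(padded)          # model returns (output, residual, None)
    return restored[..., :h, :w]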
models/XYScanNetP.py ADDED
@@ -0,0 +1,737 @@
1
+ import numbers
2
+ import math
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch import Tensor
9
+
10
+ from einops import rearrange, repeat
11
+
12
+ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
13
+
14
+ try:
15
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
16
+ except ImportError:
17
+ causal_conv1d_fn, causal_conv1d_update = None, None
18
+
19
+ try:
20
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
21
+ except ImportError:
22
+ selective_state_update = None
23
+
24
+ try:
25
+ from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
26
+ except ImportError:
27
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
28
+
29
+
30
+ def to_3d(x):
31
+ return rearrange(x, 'b c h w -> b (h w) c')
32
+
33
+
34
+ def to_4d(x, h, w):
35
+ return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
36
+
37
+
38
+ class BiasFree_LayerNorm(nn.Module):
39
+ def __init__(self, normalized_shape):
40
+ super(BiasFree_LayerNorm, self).__init__()
41
+ if isinstance(normalized_shape, numbers.Integral):
42
+ normalized_shape = (normalized_shape,)
43
+ normalized_shape = torch.Size(normalized_shape)
44
+
45
+ assert len(normalized_shape) == 1
46
+
47
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
48
+ self.normalized_shape = normalized_shape
49
+
50
+ def forward(self, x):
51
+ sigma = x.var(-1, keepdim=True, unbiased=False)
52
+ return x / torch.sqrt(sigma + 1e-5) * self.weight
53
+
54
+
55
+ class WithBias_LayerNorm(nn.Module):
56
+ def __init__(self, normalized_shape):
57
+ super(WithBias_LayerNorm, self).__init__()
58
+ if isinstance(normalized_shape, numbers.Integral):
59
+ normalized_shape = (normalized_shape,)
60
+ normalized_shape = torch.Size(normalized_shape)
61
+
62
+ assert len(normalized_shape) == 1
63
+
64
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
65
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
66
+ self.normalized_shape = normalized_shape
67
+
68
+ def forward(self, x):
69
+ mu = x.mean(-1, keepdim=True)
70
+ sigma = x.var(-1, keepdim=True, unbiased=False)
71
+ return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
72
+
73
+
74
+ class LayerNorm(nn.Module):
75
+ def __init__(self, dim, LayerNorm_type):
76
+ super(LayerNorm, self).__init__()
77
+ if LayerNorm_type == 'BiasFree':
78
+ self.body = BiasFree_LayerNorm(dim)
79
+ else:
80
+ self.body = WithBias_LayerNorm(dim)
81
+
82
+ def forward(self, x):
83
+ h, w = x.shape[-2:]
84
+ return to_4d(self.body(to_3d(x)), h, w)
85
+
86
+ ##########################################################################
87
+ def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1):
88
+ return nn.Conv2d(
89
+ in_channels, out_channels, kernel_size,
90
+ padding=(kernel_size//2), bias=bias, stride = stride)
91
+
92
+
93
+ """
94
+ Borrowed from "https://github.com/state-spaces/mamba.git"
95
+ @article{mamba,
96
+ title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
97
+ author={Gu, Albert and Dao, Tri},
98
+ journal={arXiv preprint arXiv:2312.00752},
99
+ year={2023}
100
+ }
101
+ """
102
+ class Mamba(nn.Module):
103
+ def __init__(
104
+ self,
105
+ d_model,
106
+ d_state=16,
107
+ d_conv=4,
108
+ expand=2,
109
+ dt_rank="auto",
110
+ dt_min=0.001,
111
+ dt_max=0.1,
112
+ dt_init="random",
113
+ dt_scale=1.0,
114
+ dt_init_floor=1e-4,
115
+ conv_bias=True,
116
+ bias=False,
117
+ use_fast_path=True, # Fused kernel options
118
+ layer_idx=None,
119
+ device=None,
120
+ dtype=None,
121
+ ):
122
+ factory_kwargs = {"device": device, "dtype": dtype}
123
+ super().__init__()
124
+ self.d_model = d_model
125
+ self.d_state = d_state
126
+ self.d_conv = d_conv
127
+ self.expand = expand
128
+ self.d_inner = int(self.expand * self.d_model)
129
+ self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
130
+ self.use_fast_path = use_fast_path
131
+ self.layer_idx = layer_idx
132
+
133
+ self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
134
+
135
+ self.conv1d = nn.Conv1d(
136
+ in_channels=self.d_inner,
137
+ out_channels=self.d_inner,
138
+ bias=conv_bias,
139
+ kernel_size=d_conv,
140
+ groups=self.d_inner,
141
+ padding=d_conv - 1,
142
+ **factory_kwargs,
143
+ )
144
+
145
+ self.activation = "silu"
146
+ self.act = nn.SiLU()
147
+
148
+ self.x_proj = nn.Linear(
149
+ self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
150
+ )
151
+ self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
152
+
153
+ # Initialize special dt projection to preserve variance at initialization
154
+ dt_init_std = self.dt_rank**-0.5 * dt_scale
155
+ if dt_init == "constant":
156
+ nn.init.constant_(self.dt_proj.weight, dt_init_std)
157
+ elif dt_init == "random":
158
+ nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
159
+ else:
160
+ raise NotImplementedError
161
+
162
+ # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
163
+ dt = torch.exp(
164
+ torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
165
+ + math.log(dt_min)
166
+ ).clamp(min=dt_init_floor)
167
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
168
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
169
+ with torch.no_grad():
170
+ self.dt_proj.bias.copy_(inv_dt)
171
+ # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
172
+ self.dt_proj.bias._no_reinit = True
173
+
174
+ # S4D real initialization
175
+ A = repeat(
176
+ torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
177
+ "n -> d n",
178
+ d=self.d_inner,
179
+ ).contiguous()
180
+ A_log = torch.log(A) # Keep A_log in fp32
181
+ self.A_log = nn.Parameter(A_log)
182
+ self.A_log._no_weight_decay = True
183
+
184
+ # D "skip" parameter
185
+ self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
186
+ self.D._no_weight_decay = True
187
+
188
+ self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
189
+
190
+ def forward(self, hidden_states, inference_params=None):
191
+ """
192
+ hidden_states: (B, L, D)
193
+ Returns: same shape as hidden_states
194
+ """
195
+ batch, seqlen, dim = hidden_states.shape
196
+
197
+ conv_state, ssm_state = None, None
198
+ if inference_params is not None:
199
+ conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
200
+ if inference_params.seqlen_offset > 0:
201
+ # The states are updated inplace
202
+ out, _, _ = self.step(hidden_states, conv_state, ssm_state)
203
+ return out
204
+
205
+ # We do matmul and transpose BLH -> HBL at the same time
206
+ xz = rearrange(
207
+ self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
208
+ "d (b l) -> b d l",
209
+ l=seqlen,
210
+ )
211
+ if self.in_proj.bias is not None:
212
+ xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
213
+
214
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
215
+ # In the backward pass we write dx and dz next to each other to avoid torch.cat
216
+ if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None: # Doesn't support outputting the states
217
+ out = mamba_inner_fn(
218
+ xz,
219
+ self.conv1d.weight,
220
+ self.conv1d.bias,
221
+ self.x_proj.weight,
222
+ self.dt_proj.weight,
223
+ self.out_proj.weight,
224
+ self.out_proj.bias,
225
+ A,
226
+ None, # input-dependent B
227
+ None, # input-dependent C
228
+ self.D.float(),
229
+ delta_bias=self.dt_proj.bias.float(),
230
+ delta_softplus=True,
231
+ )
232
+ else:
233
+ x, z = xz.chunk(2, dim=1)
234
+ # Compute short convolution
235
+ if conv_state is not None:
236
+ # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
237
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
238
+ conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
239
+ if causal_conv1d_fn is None:
240
+ x = self.act(self.conv1d(x)[..., :seqlen])
241
+ else:
242
+ assert self.activation in ["silu", "swish"]
243
+ x = causal_conv1d_fn(
244
+ x=x,
245
+ weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
246
+ bias=self.conv1d.bias,
247
+ activation=self.activation,
248
+ )
249
+
250
+ # We're careful here about the layout, to avoid extra transposes.
251
+ # We want dt to have d as the slowest moving dimension
252
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
253
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
254
+ dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
255
+ dt = self.dt_proj.weight @ dt.t()
256
+ dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
257
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
258
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
259
+ assert self.activation in ["silu", "swish"]
260
+ y = selective_scan_fn(
261
+ x,
262
+ dt,
263
+ A,
264
+ B,
265
+ C,
266
+ self.D.float(),
267
+ z=z,
268
+ delta_bias=self.dt_proj.bias.float(),
269
+ delta_softplus=True,
270
+ return_last_state=ssm_state is not None,
271
+ )
272
+ if ssm_state is not None:
273
+ y, last_state = y
274
+ ssm_state.copy_(last_state)
275
+ y = rearrange(y, "b d l -> b l d")
276
+ out = self.out_proj(y)
277
+ return out
278
+
279
+ def step(self, hidden_states, conv_state, ssm_state):
280
+ dtype = hidden_states.dtype
281
+ assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
282
+ xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
283
+ x, z = xz.chunk(2, dim=-1) # (B D)
284
+
285
+ # Conv step
286
+ if causal_conv1d_update is None:
287
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
288
+ conv_state[:, :, -1] = x
289
+ x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
290
+ if self.conv1d.bias is not None:
291
+ x = x + self.conv1d.bias
292
+ x = self.act(x).to(dtype=dtype)
293
+ else:
294
+ x = causal_conv1d_update(
295
+ x,
296
+ conv_state,
297
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
298
+ self.conv1d.bias,
299
+ self.activation,
300
+ )
301
+
302
+ x_db = self.x_proj(x) # (B dt_rank+2*d_state)
303
+ dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
304
+ # Don't add dt_bias here
305
+ dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
306
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
307
+
308
+ # SSM step
309
+ if selective_state_update is None:
310
+ # Discretize A and B
311
+ dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
312
+ dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
313
+ dB = torch.einsum("bd,bn->bdn", dt, B)
314
+ ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
315
+ y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
316
+ y = y + self.D.to(dtype) * x
317
+ y = y * self.act(z) # (B D)
318
+ else:
319
+ y = selective_state_update(
320
+ ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
321
+ )
322
+
323
+ out = self.out_proj(y)
324
+ return out.unsqueeze(1), conv_state, ssm_state
325
+
326
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
327
+ device = self.out_proj.weight.device
328
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
329
+ conv_state = torch.zeros(
330
+ batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
331
+ )
332
+ ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
333
+ # ssm_dtype = torch.float32
334
+ ssm_state = torch.zeros(
335
+ batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
336
+ )
337
+ return conv_state, ssm_state
338
+
339
+ def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
340
+ assert self.layer_idx is not None
341
+ if self.layer_idx not in inference_params.key_value_memory_dict:
342
+ batch_shape = (batch_size,)
343
+ conv_state = torch.zeros(
344
+ batch_size,
345
+ self.d_model * self.expand,
346
+ self.d_conv,
347
+ device=self.conv1d.weight.device,
348
+ dtype=self.conv1d.weight.dtype,
349
+ )
350
+ ssm_state = torch.zeros(
351
+ batch_size,
352
+ self.d_model * self.expand,
353
+ self.d_state,
354
+ device=self.dt_proj.weight.device,
355
+ dtype=self.dt_proj.weight.dtype,
356
+ # dtype=torch.float32,
357
+ )
358
+ inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
359
+ else:
360
+ conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
361
+ # TODO: What if batch size changes between generation, and we reuse the same states?
362
+ if initialize_states:
363
+ conv_state.zero_()
364
+ ssm_state.zero_()
365
+ return conv_state, ssm_state
366
+
367
+ ##########################################################################
368
+ ## Feed-forward Network
369
+ class FFN(nn.Module):
370
+ def __init__(self, dim, ffn_expansion_factor, bias):
371
+ super(FFN, self).__init__()
372
+
373
+ hidden_features = int(dim*ffn_expansion_factor)
374
+
375
+ self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
376
+
377
+ self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias, dilation=1)
378
+
379
+ self.win_size = 8
380
+
381
+ self.modulator = nn.Parameter(torch.ones(self.win_size, self.win_size, dim*2)) # modulator
382
+
383
+ self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
384
+
385
+ def forward(self, x):
386
+ b, c, h, w = x.shape
387
+ h1, w1 = h//self.win_size, w//self.win_size
388
+ x = self.project_in(x)
389
+ x = self.dwconv(x)
390
+ x_win = rearrange(x, 'b c (wsh h1) (wsw w1) -> b h1 w1 wsh wsw c', wsh=self.win_size, wsw=self.win_size)
391
+ x_win = x_win * self.modulator
392
+ x = rearrange(x_win, 'b h1 w1 wsh wsw c -> b c (wsh h1) (wsw w1)', wsh=self.win_size, wsw=self.win_size, h1=h1, w1=w1)
393
+ x1, x2 = x.chunk(2, dim=1)
394
+ x = x1 * x2
395
+ x = self.project_out(x)
396
+ return x
397
+
398
+
399
+ ##########################################################################
400
+ ## Gated Depth-wise Feed-forward Network (GDFN)
401
+ class GDFN(nn.Module):
402
+ def __init__(self, dim, ffn_expansion_factor, bias):
403
+ super(GDFN, self).__init__()
404
+
405
+ hidden_features = int(dim*ffn_expansion_factor)
406
+
407
+ self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
408
+
409
+ self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias, dilation=1)
410
+
411
+ self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
412
+
413
+ def forward(self, x):
414
+ x = self.project_in(x)
415
+ x = self.dwconv(x)
416
+ x1, x2 = x.chunk(2, dim=1)
417
+ x = F.silu(x1) * x2
418
+ x = self.project_out(x)
419
+ return x
420
+
421
+
422
+ ##########################################################################
423
+ ## Overlapped image patch embedding with 3x3 Conv
424
+ class OverlapPatchEmbed(nn.Module):
425
+ def __init__(self, in_c=3, embed_dim=48, bias=False):
426
+ super(OverlapPatchEmbed, self).__init__()
427
+
428
+ self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)
429
+
430
+ def forward(self, x):
431
+ x = self.proj(x)
432
+
433
+ return x
434
+
435
+
436
+ ##########################################################################
437
+ ## Resizing modules
438
+ class Downsample(nn.Module):
439
+ def __init__(self, n_feat):
440
+ super(Downsample, self).__init__()
441
+
442
+ self.body = nn.Sequential(nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
443
+ nn.Conv2d(n_feat, n_feat * 2, 3, stride=1, padding=1, bias=False))
444
+
445
+ def forward(self, x):
446
+ return self.body(x)
447
+
448
+
449
+ class Upsample(nn.Module):
450
+ def __init__(self, n_feat):
451
+ super(Upsample, self).__init__()
452
+
453
+ self.body = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
454
+ nn.Conv2d(n_feat, n_feat // 2, 3, stride=1, padding=1, bias=False))
455
+
456
+ def forward(self, x):
457
+ return self.body(x)
458
+
459
+
460
+ """
461
+ Borrowed from "https://github.com/pp00704831/Stripformer-ECCV-2022-.git"
462
+ @inproceedings{Tsai2022Stripformer,
463
+ author = {Fu-Jen Tsai and Yan-Tsung Peng and Yen-Yu Lin and Chung-Chi Tsai and Chia-Wen Lin},
464
+ title = {Stripformer: Strip Transformer for Fast Image Deblurring},
465
+ booktitle = {ECCV},
466
+ year = {2022}
467
+ }
468
+ """
469
+ class Intra_VSSM(nn.Module):
470
+ def __init__(self, dim, vssm_expansion_factor, bias): # gated = True
471
+ super(Intra_VSSM, self).__init__()
472
+ hidden = int(dim*vssm_expansion_factor)
473
+
474
+ self.proj_in = nn.Conv2d(dim, hidden*2, kernel_size=1, bias=bias)
475
+ self.dwconv = nn.Conv2d(hidden*2, hidden*2, kernel_size=3, stride=1, padding=1, groups=hidden*2, bias=bias)
476
+ self.proj_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)
477
+
478
+ self.conv_input = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
479
+ self.fuse_out = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
480
+ self.mamba = Mamba(d_model=hidden // 2)
481
+
482
+ def forward_core(self, x):
483
+ B, C, H, W = x.size()
484
+
485
+ x_input = torch.chunk(self.conv_input(x), 2, dim=1)
486
+
487
+ feature_h = (x_input[0]).permute(0, 2, 3, 1).contiguous()
488
+ feature_h = feature_h.view(B * H, W, C//2)
489
+
490
+ feature_v = (x_input[1]).permute(0, 3, 2, 1).contiguous()
491
+ feature_v = feature_v.view(B * W, H, C//2)
492
+
493
+ if H == W:
494
+ feature = torch.cat((feature_h, feature_v), dim=0) # B * H * 2, W, C//2
495
+ scan_output = self.mamba(feature)
496
+ scan_output = torch.chunk(scan_output, 2, dim=0)
497
+ scan_output_h = scan_output[0]
498
+ scan_output_v = scan_output[1]
499
+ else:
500
+ scan_output_h = self.mamba(feature_h)
501
+ scan_output_v = self.mamba(feature_v)
502
+
503
+ scan_output_h = scan_output_h.view(B, H, W, C//2).permute(0, 3, 1, 2).contiguous()
504
+ scan_output_v = scan_output_v.view(B, W, H, C//2).permute(0, 3, 2, 1).contiguous()
505
+ scan_output = self.fuse_out(torch.cat((scan_output_h, scan_output_v), dim=1))
506
+
507
+ return scan_output
508
+
509
+ def forward(self, x):
510
+ x = self.proj_in(x)
511
+ x, x_ = self.dwconv(x).chunk(2, dim=1)
512
+ x = self.forward_core(x)
513
+ x = F.silu(x_) * x
514
+ x = self.proj_out(x)
515
+ return x
516
+
517
+
518
+ class Inter_VSSM(nn.Module):
519
+ def __init__(self, dim, vssm_expansion_factor, bias): # gated = True
520
+ super(Inter_VSSM, self).__init__()
521
+ hidden = int(dim*vssm_expansion_factor)
522
+
523
+ self.proj_in = nn.Conv2d(dim, hidden*2, kernel_size=1, bias=bias)
524
+ self.dwconv = nn.Conv2d(hidden*2, hidden*2, kernel_size=3, stride=1, padding=1, groups=hidden*2, bias=bias)
525
+ self.proj_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)
526
+
527
+ self.avg_pool = nn.AdaptiveAvgPool2d((None,1))
528
+ self.conv_input = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
529
+ self.fuse_out = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
530
+ self.mamba = Mamba(d_model=hidden // 2)
531
+ self.sigmoid = nn.Sigmoid()
532
+
533
+ def forward_core(self, x):
534
+ B, C, H, W = x.size()
535
+
536
+ x_input = torch.chunk(self.conv_input(x), 2, dim=1) # B, C, H, W
537
+
538
+ feature_h = x_input[0].permute(0, 2, 1, 3).contiguous() # B, H, C//2, W
539
+ feature_h_score = self.avg_pool(feature_h) # B, H, C//2, 1
540
+ feature_h_score = feature_h_score.view(B, H, -1)
541
+
542
+ feature_v = x_input[1].permute(0, 3, 1, 2).contiguous() # B, W, C//2, H
543
+ feature_v_score = self.avg_pool(feature_v) # B, W, C//2, 1
544
+ feature_v_score = feature_v_score.view(B, W, -1)
545
+
546
+ if H == W:
547
+ feature_score = torch.cat((feature_h_score, feature_v_score), dim=0) # B * 2, W or H, C//2
548
+ scan_score = self.mamba(feature_score)
549
+ scan_score = torch.chunk(scan_score, 2, dim=0)
550
+ scan_score_h = scan_score[0]
551
+ scan_score_v = scan_score[1]
552
+ else:
553
+ scan_score_h = self.mamba(feature_h_score)
554
+ scan_score_v = self.mamba(feature_v_score)
555
+
556
+ scan_score_h = self.sigmoid(scan_score_h)
557
+ scan_score_v = self.sigmoid(scan_score_v)
558
+ feature_h = feature_h*scan_score_h[:,:,:,None]
559
+ feature_v = feature_v*scan_score_v[:,:,:,None]
560
+ feature_h = feature_h.view(B, H, C//2, W).permute(0, 2, 1, 3).contiguous()
561
+ feature_v = feature_v.view(B, W, C//2, H).permute(0, 2, 3, 1).contiguous()
562
+ output = self.fuse_out(torch.cat((feature_h, feature_v), dim=1))
563
+
564
+ return output
565
+
566
+ def forward(self, x):
567
+ x = self.proj_in(x)
568
+ x, x_ = self.dwconv(x).chunk(2, dim=1)
569
+ x = self.forward_core(x)
570
+ x = F.silu(x_) * x
571
+ x = self.proj_out(x)
572
+ return x
573
+
574
+
575
+ ##########################################################################
576
+ class Strip_VSSB(nn.Module):
577
+ def __init__(self, dim, vssm_expansion_factor, ffn_expansion_factor, bias=False, ssm=False, LayerNorm_type='WithBias'):
578
+ super(Strip_VSSB, self).__init__()
579
+ self.ssm = ssm
580
+ if self.ssm == True:
581
+ self.norm1_ssm = LayerNorm(dim, LayerNorm_type)
582
+ self.norm2_ssm = LayerNorm(dim, LayerNorm_type)
583
+ self.intra = Intra_VSSM(dim, vssm_expansion_factor, bias)
584
+ self.inter = Inter_VSSM(dim, vssm_expansion_factor, bias)
585
+ self.norm1_ffn = LayerNorm(dim, LayerNorm_type)
586
+ self.norm2_ffn = LayerNorm(dim, LayerNorm_type)
587
+ self.ffn1 = GDFN(dim, ffn_expansion_factor, bias)
588
+ self.ffn2 = GDFN(dim, ffn_expansion_factor, bias)
589
+
590
+ def forward(self, x):
591
+ if self.ssm == True:
592
+ x = x + self.intra(self.norm1_ssm(x))
593
+ x = x + self.ffn1(self.norm1_ffn(x))
594
+ if self.ssm == True:
595
+ x = x + self.inter(self.norm2_ssm(x))
596
+ x = x + self.ffn2(self.norm2_ffn(x))
597
+
598
+ return x
599
+
600
+
601
+ ##########################################################################
602
+ ##---------- Cross-level Feature Fusion by Adding Sigmoid(KL-Div) * Multi-Scale Feat -----------------------
603
+ class CLFF(nn.Module):
604
+ def __init__(self, dim, dim_n1, dim_n2, bias=False):
605
+ super(CLFF, self).__init__()
606
+
607
+ self.conv = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
608
+ self.conv_n1 = nn.Conv2d(dim_n1, dim, kernel_size=1, bias=bias)
609
+ self.conv_n2 = nn.Conv2d(dim_n2, dim, kernel_size=1, bias=bias)
610
+ self.fuse_out1 = nn.Conv2d(dim*2, dim, kernel_size=1, bias=bias)
611
+
612
+ self.log_sigmoid = nn.LogSigmoid()
613
+ self.sigmoid = nn.Sigmoid()
614
+
615
+ def forward(self, x, n1, n2):
616
+ x_ = self.conv(x)
617
+ n1_ = self.conv_n1(n1)
618
+ n2_ = self.conv_n2(n2)
619
+ kl_n1 = F.kl_div(input=self.log_sigmoid(n1_), target=self.log_sigmoid(x_), log_target=True)
620
+ kl_n2 = F.kl_div(input=self.log_sigmoid(n2_), target=self.log_sigmoid(x_), log_target=True)
621
+ #g = self.sigmoid(x_)
622
+ g1 = self.sigmoid(kl_n1)
623
+ g2 = self.sigmoid(kl_n2)
624
+ #x = (1 + g) * x_ + (1 - g) * (g1 * n1_ + g2 * n2_)
625
+ x = self.fuse_out1(torch.cat((x_, g1 * n1_ + g2 * n2_), dim=1))
626
+
627
+ return x
628
+
629
+ ##########################################################################
630
+ ##---------- XYScanNetP -----------------------
631
+ class XYScanNetP(nn.Module):
632
+ def __init__(self,
633
+ inp_channels=3,
634
+ out_channels=3,
635
+ dim = 144, # 48, 72, 96, 120, 144
636
+ num_blocks = [3,3,6],
637
+ vssm_expansion_factor = 1, # 1 or 2
638
+ ffn_expansion_factor = 1, # 1 or 3
639
+ bias = False,
640
+ LayerNorm_type = 'WithBias', ## Other option 'BiasFree'
641
+ ):
642
+
643
+ super(XYScanNetP, self).__init__()
644
+
645
+ self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
646
+
647
+ self.encoder_level1 = nn.Sequential(*[Strip_VSSB(dim=dim, vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
648
+ bias=bias, ssm=False, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
649
+
650
+ self.down1_2 = Downsample(dim) ## From Level 1 to Level 2
651
+ self.encoder_level2 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**1), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
652
+ bias=bias, ssm=False, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
653
+
654
+ self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3
655
+ self.encoder_level3 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**2), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
656
+ bias=bias, ssm=False, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
657
+
658
+ self.decoder_level3 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**2), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
659
+ bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
660
+
661
+ self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2
662
+ self.clff_level2 = CLFF(int(dim*2**1), dim_n1=int(dim*2**0), dim_n2=(dim*2**2), bias=bias)
663
+ self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
664
+ self.decoder_level2 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**1), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
665
+ bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
666
+
667
+ self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1
668
+ self.clff_level1 = CLFF(int(dim*2**0), dim_n1=int(dim*2**1), dim_n2=(dim*2**2), bias=bias)
669
+ self.reduce_chan_level1 = nn.Conv2d(int(dim*2**1), int(dim*2**0), kernel_size=1, bias=bias)
670
+ self.decoder_level1 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**0), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
671
+ bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
672
+
673
+ # self.refinement = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**0), expansion_factor=expansion_factor, bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])
674
+
675
+ self.output = nn.Conv2d(int(dim*2**0), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
676
+
677
+ def forward(self, inp_img):
678
+
679
+ # Encoder
680
+ inp_enc_level1 = self.patch_embed(inp_img)
681
+ out_enc_level1 = self.encoder_level1(inp_enc_level1)
682
+ out_enc_level1_2 = F.interpolate(out_enc_level1, scale_factor=0.5) # dim*2, lvl1 down-scaled to lvl2
683
+
684
+ inp_enc_level2 = self.down1_2(out_enc_level1)
685
+ out_enc_level2 = self.encoder_level2(inp_enc_level2)
686
+ out_enc_level2_1 = F.interpolate(out_enc_level2, scale_factor=2) # dim*2, lvl2 up-scaled to lvl1
687
+
688
+ inp_enc_level3 = self.down2_3(out_enc_level2)
689
+ out_enc_level3 = self.encoder_level3(inp_enc_level3)
690
+ out_enc_level3_2 = F.interpolate(out_enc_level3, scale_factor=2) # dim*2**2, lvl3 up-scaled to lvl2 (lvl3->lvl2)
691
+ out_enc_level3_1 = F.interpolate(out_enc_level3_2, scale_factor=2) # dim*2**2, lvl3 up-scaled to lvl1 (lvl3->lvl2->lvl1)
692
+
693
+ out_enc_level1 = self.clff_level1(out_enc_level1, out_enc_level2_1, out_enc_level3_1)
694
+ out_enc_level2 = self.clff_level2(out_enc_level2, out_enc_level1_2, out_enc_level3_2)
695
+
696
+ # Decoder
697
+ out_dec_level3_decomp1 = self.decoder_level3(out_enc_level3)
698
+
699
+ inp_dec_level2_decomp1 = self.up3_2(out_dec_level3_decomp1)
700
+ inp_dec_level2_decomp1 = self.reduce_chan_level2(torch.cat((inp_dec_level2_decomp1, out_enc_level2), dim=1))
701
+ out_dec_level2_decomp1 = self.decoder_level2(inp_dec_level2_decomp1)
702
+
703
+ inp_dec_level1_decomp1 = self.up2_1(out_dec_level2_decomp1)
704
+ inp_dec_level1_decomp1 = self.reduce_chan_level1(torch.cat((inp_dec_level1_decomp1, out_enc_level1), dim=1))
705
+ out_dec_level1_decomp1 = self.decoder_level1(inp_dec_level1_decomp1)
706
+
707
+ out_dec_level1_decomp1 = self.output(out_dec_level1_decomp1)
708
+
709
+ out_dec_level1 = out_dec_level1_decomp1 + inp_img
710
+
711
+
712
+ return out_dec_level1, out_dec_level1_decomp1, None
713
+
714
+ def count_parameters(model):
715
+ total = sum(p.numel() for p in model.parameters())
716
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
717
+ print(f"Total parameters: {total:,}")
718
+ print(f"Trainable parameters: {trainable:,}")
719
+
720
+
721
+ def main():
722
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
723
+ model = XYScanNetP().to(device)
724
+
725
+ print("Model architecture:\n")
726
+ print(model)
727
+
728
+ count_parameters(model)
729
+
730
+ # Optionally test with a dummy input
731
+ dummy_input = torch.randn(1, 3, 256, 256).to(device)
732
+ output, _, _ = model(dummy_input)
733
+ print(f"Output shape: {output.shape}")
734
+
735
+
736
+ if __name__ == "__main__":
737
+ main()
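XYScanNetP is the same architecture as XYScanNet with the base width raised from dim=72 to dim=144. A sketch for comparing the two model sizes (assumes mamba_ssm is installed; the resulting numbers are not quoted here):

from models.XYScanNet import XYScanNet
from models.XYScanNetP import XYScanNetP

for cls in (XYScanNet, XYScanNetP):
    model = cls()
    total = sum(p.numel() for p in model.parameters())
    print(f"{cls.__name__}: {total / 1e6:.1f}M parameters")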
models/__init__.py ADDED
File without changes
models/__pycache__/XYScanNet.cpython-38.pyc ADDED
Binary file (21.2 kB). View file
 
models/__pycache__/XYScanNetP.cpython-38.pyc ADDED
Binary file (21.2 kB). View file
 
models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (140 Bytes). View file
 
models/__pycache__/networks.cpython-38.pyc ADDED
Binary file (691 Bytes). View file
 
models/losses.py ADDED
@@ -0,0 +1,233 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.models as models
4
+ import torch.nn.functional as F
5
+
6
+ class Vgg19(torch.nn.Module):
7
+ def __init__(self, requires_grad=False):
8
+ super(Vgg19, self).__init__()
9
+ vgg_pretrained_features = vgg19(pretrained=True).features
10
+ self.slice1 = torch.nn.Sequential()
11
+
12
+ for x in range(12):
13
+ self.slice1.add_module(str(x), vgg_pretrained_features[x].eval())
14
+
15
+ if not requires_grad:
16
+ for param in self.parameters():
17
+ param.requires_grad = False
18
+
19
+ def forward(self, X):
20
+ h_relu1 = self.slice1(X)
21
+ return h_relu1
22
+
23
+ class ContrastLoss(nn.Module):
24
+ def __init__(self, ablation=False):
25
+
26
+ super(ContrastLoss, self).__init__()
27
+ self.vgg = Vgg19().cuda()
28
+ self.l1 = nn.L1Loss()
29
+ self.ab = ablation
30
+ self.down_sample_4 = nn.Upsample(scale_factor=1 / 4, mode='bilinear')
31
+ def forward(self, restore, sharp, blur):
32
+ B, C, H, W = restore.size()
33
+ restore_vgg, sharp_vgg, blur_vgg = self.vgg(restore), self.vgg(sharp), self.vgg(blur)
34
+
35
+ # filter out sharp regions
36
+ threshold = 0.01
37
+ mask = torch.mean(torch.abs(sharp-blur), dim=1).view(B, 1, H, W)
38
+ mask[mask <= threshold] = 0
39
+ mask[mask > threshold] = 1
40
+ mask = self.down_sample_4(mask)
41
+ d_ap = torch.mean(torch.abs((restore_vgg - sharp_vgg.detach())), dim=1).view(B, 1, H//4, W//4)
42
+ d_an = torch.mean(torch.abs((restore_vgg - blur_vgg.detach())), dim=1).view(B, 1, H//4, W//4)
43
+ mask_size = torch.sum(mask)
44
+ contrastive = torch.sum((d_ap / (d_an + 1e-7)) * mask) / mask_size
45
+
46
+ return contrastive
47
+
48
+
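The masked contrastive term above pulls the restored image's VGG features toward the sharp target (d_ap) and away from the blurry input (d_an), but only at locations where blur actually changed the pixel (|sharp - blur| > 0.01, downsampled to the VGG feature resolution). The core masked ratio in isolation (random tensors, illustrative only):

import torch

d_ap = torch.rand(2, 1, 8, 8)             # distance to the sharp features
d_an = torch.rand(2, 1, 8, 8) + 0.5       # distance to the blurry features
mask = (torch.rand(2, 1, 8, 8) > 0.5).float()
contrastive = torch.sum((d_ap / (d_an + 1e-7)) * mask) / torch.sum(mask)
print(contrastive)                        # smaller is better: close to sharp, far from blur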
49
+ class ContrastLoss_Ori(nn.Module):
50
+ def __init__(self, ablation=False):
51
+ super(ContrastLoss_Ori, self).__init__()
52
+ self.vgg = Vgg19().cuda()
53
+ self.l1 = nn.L1Loss()
54
+ self.ab = ablation
55
+
56
+ def forward(self, restore, sharp, blur):
57
+
58
+ restore_vgg, sharp_vgg, blur_vgg = self.vgg(restore), self.vgg(sharp), self.vgg(blur)
59
+ d_ap = self.l1(restore_vgg, sharp_vgg.detach())
60
+ d_an = self.l1(restore_vgg, blur_vgg.detach())
61
+ contrastive_loss = d_ap / (d_an + 1e-7)
62
+
63
+ return contrastive_loss
64
+
65
+
66
+ class CharbonnierLoss(nn.Module):
67
+ """Charbonnier Loss (L1)"""
68
+
69
+ def __init__(self, eps=1e-3):
70
+ super(CharbonnierLoss, self).__init__()
71
+ self.eps = eps
72
+
73
+ def forward(self, x, y):
74
+ diff = x - y
75
+ # loss = torch.sum(torch.sqrt(diff * diff + self.eps))
76
+ loss = torch.mean(torch.sqrt((diff * diff) + (self.eps * self.eps)))
77
+ return loss
78
+
79
+
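The Charbonnier loss above is a smooth L1 penalty, L = mean(sqrt((x - y)^2 + eps^2)). A quick numeric check with the class defined above:

import torch

loss = CharbonnierLoss(eps=1e-3)
x = torch.zeros(1, 3, 4, 4)
y = torch.full((1, 3, 4, 4), 0.1)
print(loss(x, y))   # ~0.1: for a constant difference of 0.1 the eps term is negligible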
80
+ class EdgeLoss(nn.Module):
81
+ def __init__(self):
82
+ super(EdgeLoss, self).__init__()
83
+ k = torch.Tensor([[.05, .25, .4, .25, .05]])
84
+ self.kernel = torch.matmul(k.t(), k).unsqueeze(0).repeat(3, 1, 1, 1)
85
+ if torch.cuda.is_available():
86
+ self.kernel = self.kernel.cuda()
87
+ self.loss = CharbonnierLoss()
88
+
89
+ def conv_gauss(self, img):
90
+ n_channels, _, kw, kh = self.kernel.shape
91
+ img = F.pad(img, (kw // 2, kh // 2, kw // 2, kh // 2), mode='replicate')
92
+ return F.conv2d(img, self.kernel, groups=n_channels)
93
+
94
+ def laplacian_kernel(self, current):
95
+ filtered = self.conv_gauss(current) # filter
96
+ down = filtered[:, :, ::2, ::2] # downsample
97
+ new_filter = torch.zeros_like(filtered)
98
+ new_filter[:, :, ::2, ::2] = down * 4 # upsample
99
+ filtered = self.conv_gauss(new_filter) # filter
100
+ diff = current - filtered
101
+ return diff
102
+
103
+ def forward(self, x, y):
104
+ # x = torch.clamp(x + 0.5, min = 0,max = 1)
105
+ # y = torch.clamp(y + 0.5, min = 0,max = 1)
106
+ loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
107
+ return loss
108
+
109
+
110
+ class Stripformer_Loss(nn.Module):
111
+
112
+ def __init__(self, ):
113
+ super(Stripformer_Loss, self).__init__()
114
+
115
+ self.char = CharbonnierLoss()
116
+ self.edge = EdgeLoss()
117
+ self.contrastive = ContrastLoss()
118
+
119
+ def forward(self, restore, sharp, blur):
120
+ char = self.char(restore, sharp)
121
+ edge = 0.05 * self.edge(restore, sharp)
122
+ contrastive = 0.0005 * self.contrastive(restore, sharp, blur)
123
+ loss = char + edge + contrastive
124
+ return loss
125
+
126
+
127
+ def get_loss(model):
128
+ if model['content_loss'] == 'Stripformer_Loss':
129
+ content_loss = Stripformer_Loss()
130
+ elif model['content_loss'] == 'CharbonnierLoss':
131
+ content_loss = CharbonnierLoss()
132
+ else:
133
+ raise ValueError("ContentLoss [%s] not recognized." % model['content_loss'])
134
+ return content_loss
135
+
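The combined objective selected by 'Stripformer_Loss' is L = L_char + 0.05 * L_edge + 0.0005 * L_contrastive, where only the contrastive term uses the blurry input. A usage sketch (assumes a CUDA device and a local VGG-19 checkpoint at the path hard-coded in _vgg below):

import torch

criterion = get_loss({'content_loss': 'Stripformer_Loss'})
restore = torch.rand(1, 3, 64, 64).cuda()
sharp = torch.rand(1, 3, 64, 64).cuda()
blur = torch.rand(1, 3, 64, 64).cuda()
print(criterion(restore, sharp, blur))   # scalar training loss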
136
+ from typing import Union, List, Dict, Any, cast
137
+
138
+ import torch
139
+ import torch.nn as nn
140
+
141
+ class VGG(nn.Module):
142
+ def __init__(
143
+ self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5
144
+ ) -> None:
145
+ super().__init__()
146
+ self.features = features
147
+ self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
148
+ self.classifier = nn.Sequential(
149
+ nn.Linear(512 * 7 * 7, 4096),
150
+ nn.ReLU(True),
151
+ nn.Dropout(p=dropout),
152
+ nn.Linear(4096, 4096),
153
+ nn.ReLU(True),
154
+ nn.Dropout(p=dropout),
155
+ nn.Linear(4096, num_classes),
156
+ )
157
+ if init_weights:
158
+ for m in self.modules():
159
+ if isinstance(m, nn.Conv2d):
160
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
161
+ if m.bias is not None:
162
+ nn.init.constant_(m.bias, 0)
163
+ elif isinstance(m, nn.BatchNorm2d):
164
+ nn.init.constant_(m.weight, 1)
165
+ nn.init.constant_(m.bias, 0)
166
+ elif isinstance(m, nn.Linear):
167
+ nn.init.normal_(m.weight, 0, 0.01)
168
+ nn.init.constant_(m.bias, 0)
169
+
170
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
171
+ x = self.features(x)
172
+ x = self.avgpool(x)
173
+ x = torch.flatten(x, 1)
174
+ x = self.classifier(x)
175
+ return x
176
+
177
+
178
+ def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
179
+ layers: List[nn.Module] = []
180
+ in_channels = 3
181
+ for v in cfg:
182
+ if v == "M":
183
+ layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
184
+ else:
185
+ v = cast(int, v)
186
+ conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
187
+ if batch_norm:
188
+ layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
189
+ else:
190
+ layers += [conv2d, nn.ReLU(inplace=True)]
191
+ in_channels = v
192
+ return nn.Sequential(*layers)
193
+
194
+
195
+ cfgs: Dict[str, List[Union[str, int]]] = {
196
+ "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
197
+ "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
198
+ "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
199
+ "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
200
+ }
201
+
202
+ def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG:
203
+ if pretrained:
204
+ kwargs["init_weights"] = False
205
+ model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
206
+ if pretrained:
207
+ state_dict = torch.load("/home/hanzhou1996/low-level/StripMamba/models/vgg19-dcbb9e9d.pth") # change the path to vgg19.pth
208
+ model.load_state_dict(state_dict)
209
+ return model
210
+
211
+
212
+ def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
213
+ r"""VGG 19-layer model (configuration "E")
214
+ `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
215
+ The required minimum input size of the model is 32x32.
216
+
217
+ Args:
218
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
219
+ progress (bool): If True, displays a progress bar of the download to stderr
220
+ """
221
+ return _vgg("vgg19", "E", False, pretrained, progress, **kwargs)
222
+ """
223
+ if __name__ == "__main__":
224
+ device = "cuda" if torch.cuda.is_available() else "cpu"
225
+ #model = VGG(make_layers(cfgs["E"], batch_norm=False)).to(device)
226
+ #model.load_state_dict(torch.load("models/vgg19-dcbb9e9d.pth"))
227
+ model = vgg19().to(device)
228
+ print(model.features)
229
+ BATCH_SIZE = 3
230
+ x = torch.randn(3, 3, 224, 224).to(device)
231
+ assert model(x).shape == torch.Size([BATCH_SIZE, 1000])
232
+ print(model(x).shape)
233
+ """
models/models.py ADDED
@@ -0,0 +1,36 @@
1
+ import numpy as np
2
+ import torch.nn as nn
3
+ #from skimage.measure import compare_ssim as SSIM
4
+ from skimage.metrics import structural_similarity as SSIM
5
+
6
+ from util.metrics import PSNR
7
+
8
+
9
+ class DeblurModel(nn.Module):
10
+ def __init__(self):
11
+ super(DeblurModel, self).__init__()
12
+
13
+ def get_input(self, data):
14
+ img = data['a']
15
+ inputs = img
16
+ targets = data['b']
17
+ inputs, targets = inputs.cuda(), targets.cuda()
18
+ return inputs, targets
19
+
20
+ def tensor2im(self, image_tensor, imtype=np.uint8):
21
+ image_numpy = image_tensor[0].cpu().float().numpy()
22
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 0.5) * 255.0
23
+ return image_numpy
24
+
25
+ def get_images_and_metrics(self, inp, output, target) -> (float, float, np.ndarray):
26
+ inp = self.tensor2im(inp)
27
+ fake = self.tensor2im(output.data)
28
+ real = self.tensor2im(target.data)
29
+ psnr = PSNR(fake, real)
30
+ ssim = SSIM(fake.astype('uint8'), real.astype('uint8'), channel_axis=2)
31
+ vis_img = np.hstack((inp, fake, real))
32
+ return psnr, ssim, vis_img
33
+
34
+
35
+ def get_model(model_config):
36
+ return DeblurModel()
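Because tensor2im above adds 0.5 before scaling to 255, the training pipeline assumes images normalized to roughly [-0.5, 0.5]. A matching conversion for a single restored tensor (illustrative):

import numpy as np
import torch

restored = torch.rand(1, 3, 64, 64) - 0.5                            # stand-in for a model output
img = (restored[0].cpu().float().numpy().transpose(1, 2, 0) + 0.5) * 255.0
img = np.clip(img, 0, 255).astype(np.uint8)                          # HWC uint8 image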
models/networks.py ADDED
@@ -0,0 +1,16 @@
1
+ import torch.nn as nn
2
+ from models.XYScanNet import XYScanNet
3
+ from models.XYScanNetP import XYScanNetP
4
+
5
+ def get_generator(model_config):
6
+ generator_name = model_config['g_name']
7
+ if generator_name == 'XYScanNet':
8
+ model_g = XYScanNet()
9
+ elif generator_name == 'XYScanNetP':
10
+ model_g = XYScanNetP()
11
+ else:
12
+ raise ValueError("Generator Network [%s] not recognized." % generator_name)
13
+ return nn.DataParallel(model_g)
14
+
15
+ def get_nets(model_config):
16
+ return get_generator(model_config)
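get_nets returns the chosen generator wrapped in nn.DataParallel, keyed by the config's g_name field; checkpoints saved from that wrapper therefore carry 'module.'-prefixed keys. Minimal usage (assumes mamba_ssm and at least one GPU):

from models.networks import get_nets

model = get_nets({'g_name': 'XYScanNetP'}).cuda()   # nn.DataParallel-wrapped generator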
models/sota/FFTformer.py ADDED
@@ -0,0 +1,324 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numbers
5
+ from einops import rearrange
6
+
7
+
8
+ def to_3d(x):
9
+ return rearrange(x, 'b c h w -> b (h w) c')
10
+
11
+
12
+ def to_4d(x, h, w):
13
+ return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
14
+
15
+
16
+ class BiasFree_LayerNorm(nn.Module):
17
+ def __init__(self, normalized_shape):
18
+ super(BiasFree_LayerNorm, self).__init__()
19
+ if isinstance(normalized_shape, numbers.Integral):
20
+ normalized_shape = (normalized_shape,)
21
+ normalized_shape = torch.Size(normalized_shape)
22
+
23
+ assert len(normalized_shape) == 1
24
+
25
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
26
+ self.normalized_shape = normalized_shape
27
+
28
+ def forward(self, x):
29
+ sigma = x.var(-1, keepdim=True, unbiased=False)
30
+ return x / torch.sqrt(sigma + 1e-5) * self.weight
31
+
32
+
33
+ class WithBias_LayerNorm(nn.Module):
34
+ def __init__(self, normalized_shape):
35
+ super(WithBias_LayerNorm, self).__init__()
36
+ if isinstance(normalized_shape, numbers.Integral):
37
+ normalized_shape = (normalized_shape,)
38
+ normalized_shape = torch.Size(normalized_shape)
39
+
40
+ assert len(normalized_shape) == 1
41
+
42
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
43
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
44
+ self.normalized_shape = normalized_shape
45
+
46
+ def forward(self, x):
47
+ mu = x.mean(-1, keepdim=True)
48
+ sigma = x.var(-1, keepdim=True, unbiased=False)
49
+ return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
50
+
51
+
52
+ class LayerNorm(nn.Module):
53
+ def __init__(self, dim, LayerNorm_type):
54
+ super(LayerNorm, self).__init__()
55
+ if LayerNorm_type == 'BiasFree':
56
+ self.body = BiasFree_LayerNorm(dim)
57
+ else:
58
+ self.body = WithBias_LayerNorm(dim)
59
+
60
+ def forward(self, x):
61
+ h, w = x.shape[-2:]
62
+ return to_4d(self.body(to_3d(x)), h, w)
63
+
64
+
65
+ class DFFN(nn.Module):
66
+ def __init__(self, dim, ffn_expansion_factor, bias):
67
+
68
+ super(DFFN, self).__init__()
69
+
70
+ hidden_features = int(dim * ffn_expansion_factor)
71
+
72
+ self.patch_size = 8
73
+
74
+ self.dim = dim
75
+ self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
76
+
77
+ self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1,
78
+ groups=hidden_features * 2, bias=bias)
79
+
80
+ self.fft = nn.Parameter(torch.ones((hidden_features * 2, 1, 1, self.patch_size, self.patch_size // 2 + 1)))
81
+ self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
82
+
83
+ def forward(self, x):
84
+ x = self.project_in(x)
85
+ x_patch = rearrange(x, 'b c (h patch1) (w patch2) -> b c h w patch1 patch2', patch1=self.patch_size,
86
+ patch2=self.patch_size)
87
+ x_patch_fft = torch.fft.rfft2(x_patch.float())
88
+ x_patch_fft = x_patch_fft * self.fft
89
+ x_patch = torch.fft.irfft2(x_patch_fft, s=(self.patch_size, self.patch_size))
90
+ x = rearrange(x_patch, 'b c h w patch1 patch2 -> b c (h patch1) (w patch2)', patch1=self.patch_size,
91
+ patch2=self.patch_size)
92
+ x1, x2 = self.dwconv(x).chunk(2, dim=1)
93
+
94
+ x = F.gelu(x1) * x2
95
+ x = self.project_out(x)
96
+ return x
97
+
98
+
99
+ class FSAS(nn.Module):
100
+ def __init__(self, dim, bias):
101
+ super(FSAS, self).__init__()
102
+
103
+ self.to_hidden = nn.Conv2d(dim, dim * 6, kernel_size=1, bias=bias)
104
+ self.to_hidden_dw = nn.Conv2d(dim * 6, dim * 6, kernel_size=3, stride=1, padding=1, groups=dim * 6, bias=bias)
105
+
106
+ self.project_out = nn.Conv2d(dim * 2, dim, kernel_size=1, bias=bias)
107
+
108
+ self.norm = LayerNorm(dim * 2, LayerNorm_type='WithBias')
109
+
110
+ self.patch_size = 8
111
+
112
+ def forward(self, x):
113
+ hidden = self.to_hidden(x)
114
+
115
+ q, k, v = self.to_hidden_dw(hidden).chunk(3, dim=1)
116
+
117
+ q_patch = rearrange(q, 'b c (h patch1) (w patch2) -> b c h w patch1 patch2', patch1=self.patch_size,
118
+ patch2=self.patch_size)
119
+ k_patch = rearrange(k, 'b c (h patch1) (w patch2) -> b c h w patch1 patch2', patch1=self.patch_size,
120
+ patch2=self.patch_size)
121
+ q_fft = torch.fft.rfft2(q_patch.float())
122
+ k_fft = torch.fft.rfft2(k_patch.float())
123
+
124
+ out = q_fft * k_fft
125
+ out = torch.fft.irfft2(out, s=(self.patch_size, self.patch_size))
126
+ out = rearrange(out, 'b c h w patch1 patch2 -> b c (h patch1) (w patch2)', patch1=self.patch_size,
127
+ patch2=self.patch_size)
128
+
129
+ out = self.norm(out)
130
+
131
+ output = v * out
132
+ output = self.project_out(output)
133
+
134
+ return output
135
+
136
+
137
+ ##########################################################################
138
+ class TransformerBlock(nn.Module):
139
+ def __init__(self, dim, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', att=False):
140
+ super(TransformerBlock, self).__init__()
141
+
142
+ self.att = att
143
+ if self.att:
144
+ self.norm1 = LayerNorm(dim, LayerNorm_type)
145
+ self.attn = FSAS(dim, bias)
146
+
147
+ self.norm2 = LayerNorm(dim, LayerNorm_type)
148
+ self.ffn = DFFN(dim, ffn_expansion_factor, bias)
149
+
150
+ def forward(self, x):
151
+ if self.att:
152
+ x = x + self.attn(self.norm1(x))
153
+
154
+ x = x + self.ffn(self.norm2(x))
155
+
156
+ return x
157
+
158
+
159
+ class Fuse(nn.Module):
160
+ def __init__(self, n_feat):
161
+ super(Fuse, self).__init__()
162
+ self.n_feat = n_feat
163
+ self.att_channel = TransformerBlock(dim=n_feat * 2)
164
+
165
+ self.conv = nn.Conv2d(n_feat * 2, n_feat * 2, 1, 1, 0)
166
+ self.conv2 = nn.Conv2d(n_feat * 2, n_feat * 2, 1, 1, 0)
167
+
168
+ def forward(self, enc, dnc):
169
+ x = self.conv(torch.cat((enc, dnc), dim=1))
170
+ x = self.att_channel(x)
171
+ x = self.conv2(x)
172
+ e, d = torch.split(x, [self.n_feat, self.n_feat], dim=1)
173
+ output = e + d
174
+
175
+ return output
176
+
177
+
178
+ ##########################################################################
179
+ ## Overlapped image patch embedding with 3x3 Conv
180
+ class OverlapPatchEmbed(nn.Module):
181
+ def __init__(self, in_c=3, embed_dim=48, bias=False):
182
+ super(OverlapPatchEmbed, self).__init__()
183
+
184
+ self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)
185
+
186
+ def forward(self, x):
187
+ x = self.proj(x)
188
+
189
+ return x
190
+
191
+
192
+ ##########################################################################
193
+ ## Resizing modules
194
+ class Downsample(nn.Module):
195
+ def __init__(self, n_feat):
196
+ super(Downsample, self).__init__()
197
+
198
+ self.body = nn.Sequential(nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
199
+ nn.Conv2d(n_feat, n_feat * 2, 3, stride=1, padding=1, bias=False))
200
+
201
+ def forward(self, x):
202
+ return self.body(x)
203
+
204
+
205
+ class Upsample(nn.Module):
206
+ def __init__(self, n_feat):
207
+ super(Upsample, self).__init__()
208
+
209
+ self.body = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
210
+ nn.Conv2d(n_feat, n_feat // 2, 3, stride=1, padding=1, bias=False))
211
+
212
+ def forward(self, x):
213
+ return self.body(x)
214
+
215
+
216
+ ##########################################################################
217
+ ##---------- FFTformer -----------------------
218
+ class fftformer(nn.Module):
219
+ def __init__(self,
220
+ inp_channels=3,
221
+ out_channels=3,
222
+ dim=8,
223
+ num_blocks=[6, 6, 12, 8],
224
+ num_refinement_blocks=4,
225
+ ffn_expansion_factor=3,
226
+ bias=False,
227
+ ):
228
+ super(fftformer, self).__init__()
229
+
230
+ self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
231
+
232
+ self.encoder_level1 = nn.Sequential(*[
233
+ TransformerBlock(dim=dim, ffn_expansion_factor=ffn_expansion_factor, bias=bias) for i in
234
+ range(num_blocks[0])])
235
+
236
+ self.down1_2 = Downsample(dim)
237
+ self.encoder_level2 = nn.Sequential(*[
238
+ TransformerBlock(dim=int(dim * 2 ** 1), ffn_expansion_factor=ffn_expansion_factor,
239
+ bias=bias) for i in range(num_blocks[1])])
240
+
241
+ self.down2_3 = Downsample(int(dim * 2 ** 1))
242
+ self.encoder_level3 = nn.Sequential(*[
243
+ TransformerBlock(dim=int(dim * 2 ** 2), ffn_expansion_factor=ffn_expansion_factor,
244
+ bias=bias) for i in range(num_blocks[2])])
245
+
246
+ self.decoder_level3 = nn.Sequential(*[
247
+ TransformerBlock(dim=int(dim * 2 ** 2), ffn_expansion_factor=ffn_expansion_factor,
248
+ bias=bias, att=True) for i in range(num_blocks[2])])
249
+
250
+ self.up3_2 = Upsample(int(dim * 2 ** 2))
251
+ self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias)
252
+ self.decoder_level2 = nn.Sequential(*[
253
+ TransformerBlock(dim=int(dim * 2 ** 1), ffn_expansion_factor=ffn_expansion_factor,
254
+ bias=bias, att=True) for i in range(num_blocks[1])])
255
+
256
+ self.up2_1 = Upsample(int(dim * 2 ** 1))
257
+
258
+ self.decoder_level1 = nn.Sequential(*[
259
+ TransformerBlock(dim=int(dim), ffn_expansion_factor=ffn_expansion_factor,
260
+ bias=bias, att=True) for i in range(num_blocks[0])])
261
+
262
+ self.refinement = nn.Sequential(*[
263
+ TransformerBlock(dim=int(dim), ffn_expansion_factor=ffn_expansion_factor,
264
+ bias=bias, att=True) for i in range(num_refinement_blocks)])
265
+
266
+ self.fuse2 = Fuse(dim * 2)
267
+ self.fuse1 = Fuse(dim)
268
+ self.output = nn.Conv2d(int(dim), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
269
+
270
+ def forward(self, inp_img):
271
+ inp_enc_level1 = self.patch_embed(inp_img)
272
+ out_enc_level1 = self.encoder_level1(inp_enc_level1)
273
+
274
+ inp_enc_level2 = self.down1_2(out_enc_level1)
275
+ out_enc_level2 = self.encoder_level2(inp_enc_level2)
276
+
277
+ inp_enc_level3 = self.down2_3(out_enc_level2)
278
+ out_enc_level3 = self.encoder_level3(inp_enc_level3)
279
+
280
+ out_dec_level3 = self.decoder_level3(out_enc_level3)
281
+
282
+ inp_dec_level2 = self.up3_2(out_dec_level3)
283
+
284
+ inp_dec_level2 = self.fuse2(inp_dec_level2, out_enc_level2)
285
+
286
+ out_dec_level2 = self.decoder_level2(inp_dec_level2)
287
+
288
+ inp_dec_level1 = self.up2_1(out_dec_level2)
289
+
290
+ inp_dec_level1 = self.fuse1(inp_dec_level1, out_enc_level1)
291
+ out_dec_level1 = self.decoder_level1(inp_dec_level1)
292
+
293
+ out_dec_level1 = self.refinement(out_dec_level1)
294
+
295
+ out_dec_level1 = self.output(out_dec_level1) + inp_img
296
+
297
+ return out_dec_level1
298
+
299
+ #"""
300
+ import time
301
+ start_time = time.time()
302
+ inp = torch.randn(1, 3, 512, 512).cuda()#.to(dtype=torch.float16)
303
+ model = fftformer().cuda()#.to(dtype=torch.float16)
304
+ out = model(inp)
305
+ print(out.shape)
306
+ print("--- %s seconds ---" % (time.time() - start_time))
307
+ pytorch_total_params = sum(p.numel() for p in model.parameters())
308
+ print("--- {num} parameters ---".format(num = pytorch_total_params))
309
+ pytorch_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
310
+ print("--- {num} trainable parameters ---".format(num = pytorch_trainable_params))
311
+ gpu_mem_usage_bytes = torch.cuda.max_memory_allocated()
312
+ print(gpu_mem_usage_bytes / 1024 / 1024 / 1024) # 64: 1.32 128: 4.94 256: 19.12; 512: OOM
313
+ #"""
314
+ """
315
+ import torch
316
+ from ptflops import get_model_complexity_info
317
+
318
+ with torch.cuda.device(0):
319
+ net = model
320
+ macs, params = get_model_complexity_info(net, (3, 256, 256), as_strings=True,
321
+ print_per_layer_stat=True, verbose=True)
322
+ print('{:<30} {:<8}'.format('Computational complexity: ', macs)) # 31.97 GMac
323
+ print('{:<30} {:<8}'.format('Number of parameters: ', params)) # 8.37 M
324
+ """
models/sota/Restormer.py ADDED
@@ -0,0 +1,340 @@
1
+ ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2
+ ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3
+ ## https://arxiv.org/abs/2111.09881
4
+
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from pdb import set_trace as stx
10
+ import numbers
11
+
12
+ from einops import rearrange
13
+
14
+
15
+
16
+ ##########################################################################
17
+ ## Layer Norm
18
+
19
+ def to_3d(x):
20
+ return rearrange(x, 'b c h w -> b (h w) c')
21
+
22
+ def to_4d(x,h,w):
23
+ return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w)
24
+
25
+ class BiasFree_LayerNorm(nn.Module):
26
+ def __init__(self, normalized_shape):
27
+ super(BiasFree_LayerNorm, self).__init__()
28
+ if isinstance(normalized_shape, numbers.Integral):
29
+ normalized_shape = (normalized_shape,)
30
+ normalized_shape = torch.Size(normalized_shape)
31
+
32
+ assert len(normalized_shape) == 1
33
+
34
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
35
+ self.normalized_shape = normalized_shape
36
+
37
+ def forward(self, x):
38
+ sigma = x.var(-1, keepdim=True, unbiased=False)
39
+ return x / torch.sqrt(sigma+1e-5) * self.weight
40
+
41
+ class WithBias_LayerNorm(nn.Module):
42
+ def __init__(self, normalized_shape):
43
+ super(WithBias_LayerNorm, self).__init__()
44
+ if isinstance(normalized_shape, numbers.Integral):
45
+ normalized_shape = (normalized_shape,)
46
+ normalized_shape = torch.Size(normalized_shape)
47
+
48
+ assert len(normalized_shape) == 1
49
+
50
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
51
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
52
+ self.normalized_shape = normalized_shape
53
+
54
+ def forward(self, x):
55
+ mu = x.mean(-1, keepdim=True)
56
+ sigma = x.var(-1, keepdim=True, unbiased=False)
57
+ return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias
58
+
59
+
60
+ class LayerNorm(nn.Module):
61
+ def __init__(self, dim, LayerNorm_type):
62
+ super(LayerNorm, self).__init__()
63
+ if LayerNorm_type =='BiasFree':
64
+ self.body = BiasFree_LayerNorm(dim)
65
+ else:
66
+ self.body = WithBias_LayerNorm(dim)
67
+
68
+ def forward(self, x):
69
+ h, w = x.shape[-2:]
70
+ return to_4d(self.body(to_3d(x)), h, w)
71
+
72
+
73
+
74
+ ##########################################################################
75
+ ## Gated-Dconv Feed-Forward Network (GDFN)
76
+ class FeedForward(nn.Module):
77
+ def __init__(self, dim, ffn_expansion_factor, bias):
78
+ super(FeedForward, self).__init__()
79
+
80
+ hidden_features = int(dim*ffn_expansion_factor)
81
+
82
+ self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
83
+
84
+ self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)
85
+
86
+ self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
87
+
88
+ def forward(self, x):
89
+ x = self.project_in(x)
90
+ x1, x2 = self.dwconv(x).chunk(2, dim=1)
91
+ x = F.gelu(x1) * x2
92
+ x = self.project_out(x)
93
+ return x
94
+
95
+
96
+
97
+ ##########################################################################
98
+ ## Multi-DConv Head Transposed Self-Attention (MDTA)
99
+ class Attention(nn.Module):
100
+ def __init__(self, dim, num_heads, bias):
101
+ super(Attention, self).__init__()
102
+ self.num_heads = num_heads
103
+ self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
104
+
105
+ self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
106
+ self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
107
+ self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
108
+
109
+
110
+
111
+ def forward(self, x):
112
+ b,c,h,w = x.shape
113
+
114
+ qkv = self.qkv_dwconv(self.qkv(x))
115
+ q,k,v = qkv.chunk(3, dim=1)
116
+
117
+ q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
118
+ k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
119
+ v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
120
+
121
+ q = torch.nn.functional.normalize(q, dim=-1)
122
+ k = torch.nn.functional.normalize(k, dim=-1)
123
+
124
+ attn = (q @ k.transpose(-2, -1)) * self.temperature
125
+ attn = attn.softmax(dim=-1)
126
+
127
+ out = (attn @ v)
128
+
129
+ out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
130
+
131
+ out = self.project_out(out)
132
+ return out
133
+
134
+
135
+
136
+ ##########################################################################
137
+ class TransformerBlock(nn.Module):
138
+ def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
139
+ super(TransformerBlock, self).__init__()
140
+
141
+ self.norm1 = LayerNorm(dim, LayerNorm_type)
142
+ self.attn = Attention(dim, num_heads, bias)
143
+ self.norm2 = LayerNorm(dim, LayerNorm_type)
144
+ self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
145
+
146
+ def forward(self, x):
147
+ x = x + self.attn(self.norm1(x))
148
+ x = x + self.ffn(self.norm2(x))
149
+
150
+ return x
151
+
152
+
153
+
154
+ ##########################################################################
155
+ ## Overlapped image patch embedding with 3x3 Conv
156
+ class OverlapPatchEmbed(nn.Module):
157
+ def __init__(self, in_c=3, embed_dim=48, bias=False):
158
+ super(OverlapPatchEmbed, self).__init__()
159
+
160
+ self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)
161
+
162
+ def forward(self, x):
163
+ x = self.proj(x)
164
+
165
+ return x
166
+
167
+
168
+
169
+ ##########################################################################
170
+ ## Resizing modules
171
+ class Downsample(nn.Module):
172
+ def __init__(self, n_feat):
173
+ super(Downsample, self).__init__()
174
+
175
+ self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False),
176
+ nn.PixelUnshuffle(2))
177
+
178
+ def forward(self, x):
179
+ return self.body(x)
180
+
181
+ class Upsample(nn.Module):
182
+ def __init__(self, n_feat):
183
+ super(Upsample, self).__init__()
184
+
185
+ self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False),
186
+ nn.PixelShuffle(2))
187
+
188
+ def forward(self, x):
189
+ return self.body(x)
190
+
191
+ ##########################################################################
192
+ class Strip_VSSB(nn.Module):
193
+ def __init__(self, dim, head_num):
194
+ super(Strip_VSSB, self).__init__()
195
+
196
+ self.intra = TransformerBlock(dim=dim, num_heads=head_num, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias')
197
+ self.inter = TransformerBlock(dim=dim, num_heads=head_num, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias')
198
+
199
+ def forward(self, x):
200
+ x = self.intra(x)
201
+ x = self.inter(x)
202
+
203
+ return x
204
+
205
+ ##########################################################################
206
+ ##---------- Restormer -----------------------
207
+ class Restormer(nn.Module):
208
+ def __init__(self,
209
+ inp_channels=3,
210
+ out_channels=3,
211
+ dim = 12,
212
+ num_blocks = [4,6,6,8],
213
+ num_refinement_blocks = 4,
214
+ heads = [1,2,4,8],
215
+ ffn_expansion_factor = 2.66,
216
+ bias = False,
217
+ LayerNorm_type = 'WithBias', ## Other option 'BiasFree'
218
+ dual_pixel_task = False ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
219
+ ):
220
+
221
+ super(Restormer, self).__init__()
222
+
223
+ self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
224
+
225
+ self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
226
+
227
+ self.down1_2 = Downsample(dim) ## From Level 1 to Level 2
228
+ self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
229
+
230
+ self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3
231
+ self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
232
+
233
+ self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4
234
+ self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])])
235
+
236
+ self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3
237
+ self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias)
238
+ self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
239
+
240
+
241
+ self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2
242
+ self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
243
+ self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
244
+
245
+ self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels)
246
+
247
+ self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
248
+
249
+ self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])
250
+
251
+ #### For Dual-Pixel Defocus Deblurring Task ####
252
+ self.dual_pixel_task = dual_pixel_task
253
+ if self.dual_pixel_task:
254
+ self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias)
255
+ ###########################
256
+
257
+ self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
258
+
259
+ def forward(self, inp_img):
260
+
261
+ inp_enc_level1 = self.patch_embed(inp_img)
262
+ out_enc_level1 = self.encoder_level1(inp_enc_level1)
263
+
264
+ inp_enc_level2 = self.down1_2(out_enc_level1)
265
+ out_enc_level2 = self.encoder_level2(inp_enc_level2)
266
+
267
+ inp_enc_level3 = self.down2_3(out_enc_level2)
268
+ out_enc_level3 = self.encoder_level3(inp_enc_level3)
269
+
270
+ inp_enc_level4 = self.down3_4(out_enc_level3)
271
+ latent = self.latent(inp_enc_level4)
272
+
273
+ inp_dec_level3 = self.up4_3(latent)
274
+ inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
275
+ inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
276
+ out_dec_level3 = self.decoder_level3(inp_dec_level3)
277
+
278
+ inp_dec_level2 = self.up3_2(out_dec_level3)
279
+ inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
280
+ inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
281
+ out_dec_level2 = self.decoder_level2(inp_dec_level2)
282
+
283
+ inp_dec_level1 = self.up2_1(out_dec_level2)
284
+ inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
285
+ out_dec_level1 = self.decoder_level1(inp_dec_level1)
286
+
287
+ out_dec_level1 = self.refinement(out_dec_level1)
288
+
289
+ #### For Dual-Pixel Defocus Deblurring Task ####
290
+ if self.dual_pixel_task:
291
+ out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1)
292
+ out_dec_level1 = self.output(out_dec_level1)
293
+ ###########################
294
+ else:
295
+ out_dec_level1 = self.output(out_dec_level1) + inp_img
296
+
297
+
298
+ return out_dec_level1
299
+
300
+ #"""
301
+ import time
302
+ start_time = time.time()
303
+ inp = torch.randn(1, 3, 256, 256).cuda()#.to(dtype=torch.float16)
304
+ model = Restormer().cuda()#.to(dtype=torch.float16)
305
+ out = model(inp)
306
+ print(out.shape)
307
+ print("--- %s seconds ---" % (time.time() - start_time))
308
+ pytorch_total_params = sum(p.numel() for p in model.parameters())
309
+ print("--- {num} parameters ---".format(num = pytorch_total_params))
310
+ pytorch_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
311
+ print("--- {num} trainable parameters ---".format(num = pytorch_trainable_params))
312
+ gpu_mem_usage_bytes = torch.cuda.max_memory_allocated()
313
+ print(gpu_mem_usage_bytes / 1024 / 1024 / 1024) # 64: 0.97 128: 3.04 256: 11.93; 512: OOM
314
+ #"""
315
+ """
316
+ import torch
317
+ from ptflops import get_model_complexity_info
318
+
319
+ with torch.cuda.device(0):
320
+ net = model
321
+ macs, params = get_model_complexity_info(net, (3, 256, 256), as_strings=True,
322
+ print_per_layer_stat=True, verbose=True)
323
+ print('{:<30} {:<8}'.format('Computational complexity: ', macs)) # 31.97 GMac
324
+ print('{:<30} {:<8}'.format('Number of parameters: ', params)) # 8.37 M
325
+ """
326
+ """
327
+ import time
328
+ start_time = time.time()
329
+ inp = torch.randn(1, 32, 64, 64).cuda()#.to(dtype=torch.float16)
330
+ model = Strip_VSSB(dim=32, head_num=4).cuda()#.to(dtype=torch.float16)
331
+ out = model(inp)
332
+ print(out.shape)
333
+ print("--- %s seconds ---" % (time.time() - start_time))
334
+ pytorch_total_params = sum(p.numel() for p in model.parameters())
335
+ print("--- {num} parameters ---".format(num = pytorch_total_params))
336
+ pytorch_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
337
+ print("--- {num} trainable parameters ---".format(num = pytorch_trainable_params))
338
+ gpu_mem_usage_bytes = torch.cuda.max_memory_allocated()
339
+ print(gpu_mem_usage_bytes / 1024 / 1024 / 1024) # 64: 0.16; 128: 0.22; 192: 0.37; 256: 0.56; 512: 2.10;
340
+ """
models/sota/Stripformer.py ADDED
@@ -0,0 +1,429 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
+ class Embeddings(nn.Module):
6
+ def __init__(self):
7
+ super(Embeddings, self).__init__()
8
+
9
+ self.activation = nn.LeakyReLU(0.2, True)
10
+
11
+ self.en_layer1_1 = nn.Sequential(
12
+ nn.Conv2d(3, 64, kernel_size=3, padding=1),
13
+ self.activation,
14
+ )
15
+ self.en_layer1_2 = nn.Sequential(
16
+ nn.Conv2d(64, 64, kernel_size=3, padding=1),
17
+ self.activation,
18
+ nn.Conv2d(64, 64, kernel_size=3, padding=1))
19
+ self.en_layer1_3 = nn.Sequential(
20
+ nn.Conv2d(64, 64, kernel_size=3, padding=1),
21
+ self.activation,
22
+ nn.Conv2d(64, 64, kernel_size=3, padding=1))
23
+ self.en_layer1_4 = nn.Sequential(
24
+ nn.Conv2d(64, 64, kernel_size=3, padding=1),
25
+ self.activation,
26
+ nn.Conv2d(64, 64, kernel_size=3, padding=1))
27
+
28
+ self.en_layer2_1 = nn.Sequential(
29
+ nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
30
+ self.activation,
31
+ )
32
+ self.en_layer2_2 = nn.Sequential(
33
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
34
+ self.activation,
35
+ nn.Conv2d(128, 128, kernel_size=3, padding=1))
36
+ self.en_layer2_3 = nn.Sequential(
37
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
38
+ self.activation,
39
+ nn.Conv2d(128, 128, kernel_size=3, padding=1))
40
+ self.en_layer2_4 = nn.Sequential(
41
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
42
+ self.activation,
43
+ nn.Conv2d(128, 128, kernel_size=3, padding=1))
44
+
45
+
46
+ self.en_layer3_1 = nn.Sequential(
47
+ nn.Conv2d(128, 320, kernel_size=3, stride=2, padding=1),
48
+ self.activation,
49
+ )
50
+
51
+
52
+ def forward(self, x):
53
+
54
+ hx = self.en_layer1_1(x)
55
+ hx = self.activation(self.en_layer1_2(hx) + hx)
56
+ hx = self.activation(self.en_layer1_3(hx) + hx)
57
+ hx = self.activation(self.en_layer1_4(hx) + hx)
58
+ residual_1 = hx
59
+ hx = self.en_layer2_1(hx)
60
+ hx = self.activation(self.en_layer2_2(hx) + hx)
61
+ hx = self.activation(self.en_layer2_3(hx) + hx)
62
+ hx = self.activation(self.en_layer2_4(hx) + hx)
63
+ residual_2 = hx
64
+ hx = self.en_layer3_1(hx)
65
+
66
+ return hx, residual_1, residual_2
67
+
68
+
69
+ class Embeddings_output(nn.Module):
70
+ def __init__(self):
71
+ super(Embeddings_output, self).__init__()
72
+
73
+ self.activation = nn.LeakyReLU(0.2, True)
74
+
75
+ self.de_layer3_1 = nn.Sequential(
76
+ nn.ConvTranspose2d(320, 192, kernel_size=4, stride=2, padding=1),
77
+ self.activation,
78
+ )
79
+ head_num = 3
80
+ dim = 192
81
+
82
+ self.de_layer2_2 = nn.Sequential(
83
+ nn.Conv2d(192+128, 192, kernel_size=1, padding=0),
84
+ self.activation,
85
+ )
86
+
87
+ self.de_block_1 = Intra_SA(dim, head_num)
88
+ self.de_block_2 = Inter_SA(dim, head_num)
89
+ self.de_block_3 = Intra_SA(dim, head_num)
90
+ self.de_block_4 = Inter_SA(dim, head_num)
91
+ self.de_block_5 = Intra_SA(dim, head_num)
92
+ self.de_block_6 = Inter_SA(dim, head_num)
93
+
94
+
95
+ self.de_layer2_1 = nn.Sequential(
96
+ nn.ConvTranspose2d(192, 64, kernel_size=4, stride=2, padding=1),
97
+ self.activation,
98
+ )
99
+
100
+ self.de_layer1_3 = nn.Sequential(
101
+ nn.Conv2d(128, 64, kernel_size=1, padding=0),
102
+ self.activation,
103
+ nn.Conv2d(64, 64, kernel_size=3, padding=1))
104
+ self.de_layer1_2 = nn.Sequential(
105
+ nn.Conv2d(64, 64, kernel_size=3, padding=1),
106
+ self.activation,
107
+ nn.Conv2d(64, 64, kernel_size=3, padding=1))
108
+ self.de_layer1_1 = nn.Sequential(
109
+ nn.Conv2d(64, 3, kernel_size=3, padding=1),
110
+ self.activation
111
+ )
112
+
113
+ def forward(self, x, residual_1, residual_2):
114
+
115
+
116
+ hx = self.de_layer3_1(x)
117
+
118
+ hx = self.de_layer2_2(torch.cat((hx, residual_2), dim = 1))
119
+ hx = self.de_block_1(hx)
120
+ hx = self.de_block_2(hx)
121
+ hx = self.de_block_3(hx)
122
+ hx = self.de_block_4(hx)
123
+ hx = self.de_block_5(hx)
124
+ hx = self.de_block_6(hx)
125
+ hx = self.de_layer2_1(hx)
126
+
127
+ hx = self.activation(self.de_layer1_3(torch.cat((hx, residual_1), dim = 1)) + hx)
128
+ hx = self.activation(self.de_layer1_2(hx) + hx)
129
+ hx = self.de_layer1_1(hx)
130
+
131
+ return hx
132
+
133
+ class Attention(nn.Module):
134
+ def __init__(self, head_num):
135
+ super(Attention, self).__init__()
136
+ self.num_attention_heads = head_num
137
+ self.softmax = nn.Softmax(dim=-1)
138
+
139
+ def transpose_for_scores(self, x):
140
+ B, N, C = x.size()
141
+ attention_head_size = int(C / self.num_attention_heads)
142
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, attention_head_size)
143
+ x = x.view(*new_x_shape)
144
+ return x.permute(0, 2, 1, 3).contiguous()
145
+
146
+ def forward(self, query_layer, key_layer, value_layer):
147
+ B, N, C = query_layer.size()
148
+ query_layer = self.transpose_for_scores(query_layer)
149
+ key_layer = self.transpose_for_scores(key_layer)
150
+ value_layer = self.transpose_for_scores(value_layer)
151
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
152
+ _, _, _, d = query_layer.size()
153
+ attention_scores = attention_scores / math.sqrt(d)
154
+ attention_probs = self.softmax(attention_scores)
155
+ context_layer = torch.matmul(attention_probs, value_layer)
156
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
157
+ new_context_layer_shape = context_layer.size()[:-2] + (C,)
158
+ attention_out = context_layer.view(*new_context_layer_shape)
159
+
160
+ return attention_out
161
+
162
+
163
+ class Mlp(nn.Module):
164
+ def __init__(self, hidden_size):
165
+ super(Mlp, self).__init__()
166
+ self.fc1 = nn.Linear(hidden_size, 4*hidden_size)
167
+ self.fc2 = nn.Linear(4*hidden_size, hidden_size)
168
+ self.act_fn = torch.nn.functional.gelu
169
+ self._init_weights()
170
+
171
+ def _init_weights(self):
172
+ nn.init.xavier_uniform_(self.fc1.weight)
173
+ nn.init.xavier_uniform_(self.fc2.weight)
174
+ nn.init.normal_(self.fc1.bias, std=1e-6)
175
+ nn.init.normal_(self.fc2.bias, std=1e-6)
176
+
177
+ def forward(self, x):
178
+ x = self.fc1(x)
179
+ x = self.act_fn(x)
180
+ x = self.fc2(x)
181
+ return x
182
+
183
+
184
+ # CPE (Conditional Positional Embedding)
185
+ class PEG(nn.Module):
186
+ def __init__(self, hidden_size):
187
+ super(PEG, self).__init__()
188
+ self.PEG = nn.Conv2d(hidden_size, hidden_size, kernel_size=3, padding=1, groups=hidden_size)
189
+
190
+ def forward(self, x):
191
+ x = self.PEG(x) + x
192
+ return x
193
+
194
+
195
+ class Intra_SA(nn.Module):
196
+ def __init__(self, dim, head_num):
197
+ super(Intra_SA, self).__init__()
198
+ self.hidden_size = dim // 2
199
+ self.head_num = head_num
200
+ self.attention_norm = nn.LayerNorm(dim)
201
+ self.conv_input = nn.Conv2d(dim, dim, kernel_size=1, padding=0)
202
+ self.qkv_local_h = nn.Linear(self.hidden_size, self.hidden_size * 3) # qkv_h
203
+ self.qkv_local_v = nn.Linear(self.hidden_size, self.hidden_size * 3) # qkv_v
204
+ self.fuse_out = nn.Conv2d(dim, dim, kernel_size=1, padding=0)
205
+ self.ffn_norm = nn.LayerNorm(dim)
206
+ self.ffn = Mlp(dim)
207
+ self.attn = Attention(head_num=self.head_num)
208
+ self.PEG = PEG(dim)
209
+ def forward(self, x):
210
+ h = x
211
+ B, C, H, W = x.size()
212
+
213
+ x = x.view(B, C, H*W).permute(0, 2, 1).contiguous()
214
+ x = self.attention_norm(x).permute(0, 2, 1).contiguous()
215
+ x = x.view(B, C, H, W)
216
+
217
+ x_input = torch.chunk(self.conv_input(x), 2, dim=1)
218
+ feature_h = (x_input[0]).permute(0, 2, 3, 1).contiguous()
219
+ feature_h = feature_h.view(B * H, W, C//2)
220
+ feature_v = (x_input[1]).permute(0, 3, 2, 1).contiguous()
221
+ feature_v = feature_v.view(B * W, H, C//2)
222
+ qkv_h = torch.chunk(self.qkv_local_h(feature_h), 3, dim=2)
223
+ qkv_v = torch.chunk(self.qkv_local_v(feature_v), 3, dim=2)
224
+ q_h, k_h, v_h = qkv_h[0], qkv_h[1], qkv_h[2]
225
+ q_v, k_v, v_v = qkv_v[0], qkv_v[1], qkv_v[2]
226
+
227
+ if H == W:
228
+ query = torch.cat((q_h, q_v), dim=0)
229
+ key = torch.cat((k_h, k_v), dim=0)
230
+ value = torch.cat((v_h, v_v), dim=0)
231
+ attention_output = self.attn(query, key, value)
232
+ attention_output = torch.chunk(attention_output, 2, dim=0)
233
+ attention_output_h = attention_output[0]
234
+ attention_output_v = attention_output[1]
235
+ attention_output_h = attention_output_h.view(B, H, W, C//2).permute(0, 3, 1, 2).contiguous()
236
+ attention_output_v = attention_output_v.view(B, W, H, C//2).permute(0, 3, 2, 1).contiguous()
237
+ attn_out = self.fuse_out(torch.cat((attention_output_h, attention_output_v), dim=1))
238
+ else:
239
+ attention_output_h = self.attn(q_h, k_h, v_h)
240
+ attention_output_v = self.attn(q_v, k_v, v_v)
241
+ attention_output_h = attention_output_h.view(B, H, W, C//2).permute(0, 3, 1, 2).contiguous()
242
+ attention_output_v = attention_output_v.view(B, W, H, C//2).permute(0, 3, 2, 1).contiguous()
243
+ attn_out = self.fuse_out(torch.cat((attention_output_h, attention_output_v), dim=1))
244
+
245
+ x = attn_out + h
246
+ x = x.view(B, C, H*W).permute(0, 2, 1).contiguous()
247
+ h = x
248
+ x = self.ffn_norm(x)
249
+ x = self.ffn(x)
250
+ x = x + h
251
+ x = x.permute(0, 2, 1).contiguous()
252
+ x = x.view(B, C, H, W)
253
+
254
+ x = self.PEG(x)
255
+
256
+ return x
257
+
258
+ class Inter_SA(nn.Module):
259
+ def __init__(self,dim, head_num):
260
+ super(Inter_SA, self).__init__()
261
+ self.hidden_size = dim
262
+ self.head_num = head_num
263
+ self.attention_norm = nn.LayerNorm(self.hidden_size)
264
+ self.conv_input = nn.Conv2d(self.hidden_size, self.hidden_size, kernel_size=1, padding=0)
265
+ self.conv_h = nn.Conv2d(self.hidden_size//2, 3 * (self.hidden_size//2), kernel_size=1, padding=0) # qkv_h
266
+ self.conv_v = nn.Conv2d(self.hidden_size//2, 3 * (self.hidden_size//2), kernel_size=1, padding=0) # qkv_v
267
+ self.ffn_norm = nn.LayerNorm(self.hidden_size)
268
+ self.ffn = Mlp(self.hidden_size)
269
+ self.fuse_out = nn.Conv2d(self.hidden_size, self.hidden_size, kernel_size=1, padding=0)
270
+ self.attn = Attention(head_num=self.head_num)
271
+ self.PEG = PEG(dim)
272
+
273
+ def forward(self, x):
274
+ h = x
275
+ B, C, H, W = x.size()
276
+
277
+ x = x.view(B, C, H*W).permute(0, 2, 1).contiguous()
278
+ x = self.attention_norm(x).permute(0, 2, 1).contiguous()
279
+ x = x.view(B, C, H, W)
280
+ #print(x.shape)
281
+
282
+ x_input = torch.chunk(self.conv_input(x), 2, dim=1)
283
+ feature_h = torch.chunk(self.conv_h(x_input[0]), 3, dim=1)
284
+ feature_v = torch.chunk(self.conv_v(x_input[1]), 3, dim=1)
285
+ query_h, key_h, value_h = feature_h[0], feature_h[1], feature_h[2]
286
+ query_v, key_v, value_v = feature_v[0], feature_v[1], feature_v[2]
287
+
288
+ horizontal_groups = torch.cat((query_h, key_h, value_h), dim=0)
289
+ horizontal_groups = horizontal_groups.permute(0, 2, 1, 3).contiguous()
290
+ horizontal_groups = horizontal_groups.view(3*B, H, -1)
291
+ horizontal_groups = torch.chunk(horizontal_groups, 3, dim=0)
292
+ query_h, key_h, value_h = horizontal_groups[0], horizontal_groups[1], horizontal_groups[2]
293
+
294
+ vertical_groups = torch.cat((query_v, key_v, value_v), dim=0)
295
+ vertical_groups = vertical_groups.permute(0, 3, 1, 2).contiguous()
296
+ vertical_groups = vertical_groups.view(3*B, W, -1)
297
+ vertical_groups = torch.chunk(vertical_groups, 3, dim=0)
298
+ query_v, key_v, value_v = vertical_groups[0], vertical_groups[1], vertical_groups[2]
299
+
300
+
301
+ if H == W:
302
+ query = torch.cat((query_h, query_v), dim=0)
303
+ key = torch.cat((key_h, key_v), dim=0)
304
+ value = torch.cat((value_h, value_v), dim=0)
305
+ attention_output = self.attn(query, key, value)
306
+ attention_output = torch.chunk(attention_output, 2, dim=0)
307
+ attention_output_h = attention_output[0]
308
+ attention_output_v = attention_output[1]
309
+ attention_output_h = attention_output_h.view(B, H, C//2, W).permute(0, 2, 1, 3).contiguous()
310
+ attention_output_v = attention_output_v.view(B, W, C//2, H).permute(0, 2, 3, 1).contiguous()
311
+ attn_out = self.fuse_out(torch.cat((attention_output_h, attention_output_v), dim=1))
312
+ else:
313
+ attention_output_h = self.attn(query_h, key_h, value_h)
314
+ attention_output_v = self.attn(query_v, key_v, value_v)
315
+ attention_output_h = attention_output_h.view(B, H, C//2, W).permute(0, 2, 1, 3).contiguous()
316
+ attention_output_v = attention_output_v.view(B, W, C//2, H).permute(0, 2, 3, 1).contiguous()
317
+ attn_out = self.fuse_out(torch.cat((attention_output_h, attention_output_v), dim=1))
318
+
319
+ x = attn_out + h
320
+ x = x.view(B, C, H*W).permute(0, 2, 1).contiguous()
321
+ h = x
322
+ x = self.ffn_norm(x)
323
+ x = self.ffn(x)
324
+ x = x + h
325
+ x = x.permute(0, 2, 1).contiguous()
326
+ x = x.view(B, C, H, W)
327
+
328
+ x = self.PEG(x)
329
+
330
+ return x
331
+
332
+ ##########################################################################
333
+ class Strip_VSSB(nn.Module):
334
+ def __init__(self, dim, head_num):
335
+ super(Strip_VSSB, self).__init__()
336
+
337
+ self.intra = Intra_SA(dim, head_num)
338
+ self.inter = Inter_SA(dim, head_num)
339
+
340
+ def forward(self, x):
341
+ x = self.intra(x)
342
+ x = self.inter(x)
343
+
344
+ return x
345
+
346
+ class Stripformer(nn.Module):
347
+ def __init__(self):
348
+ super(Stripformer, self).__init__()
349
+
350
+ self.encoder = Embeddings()
351
+ head_num = 5
352
+ dim = 320
353
+ self.Trans_block_1 = Intra_SA(dim, head_num)
354
+ self.Trans_block_2 = Inter_SA(dim, head_num)
355
+ self.Trans_block_3 = Intra_SA(dim, head_num)
356
+ self.Trans_block_4 = Inter_SA(dim, head_num)
357
+ self.Trans_block_5 = Intra_SA(dim, head_num)
358
+ self.Trans_block_6 = Inter_SA(dim, head_num)
359
+ self.Trans_block_7 = Intra_SA(dim, head_num)
360
+ self.Trans_block_8 = Inter_SA(dim, head_num)
361
+ self.Trans_block_9 = Intra_SA(dim, head_num)
362
+ self.Trans_block_10 = Inter_SA(dim, head_num)
363
+ self.Trans_block_11 = Intra_SA(dim, head_num)
364
+ self.Trans_block_12 = Inter_SA(dim, head_num)
365
+ self.decoder = Embeddings_output()
366
+
367
+
368
+ def forward(self, x):
369
+
370
+ hx, residual_1, residual_2 = self.encoder(x)
371
+ hx = self.Trans_block_1(hx)
372
+ hx = self.Trans_block_2(hx)
373
+ hx = self.Trans_block_3(hx)
374
+ hx = self.Trans_block_4(hx)
375
+ hx = self.Trans_block_5(hx)
376
+ hx = self.Trans_block_6(hx)
377
+ hx = self.Trans_block_7(hx)
378
+ hx = self.Trans_block_8(hx)
379
+ hx = self.Trans_block_9(hx)
380
+ hx = self.Trans_block_10(hx)
381
+ hx = self.Trans_block_11(hx)
382
+ hx = self.Trans_block_12(hx)
383
+ hx = self.decoder(hx, residual_1, residual_2)
384
+
385
+ return hx + x
386
+
387
+ #"""
388
+ import time
389
+ start_time = time.time()
390
+ inp = torch.randn(1, 3, 64, 64).cuda()
391
+ model = Stripformer().cuda()
392
+ out = model(inp)
393
+ print(out.shape)
394
+ print("--- %s seconds ---" % (time.time() - start_time))
395
+ pytorch_total_params = sum(p.numel() for p in model.parameters())
396
+ print("--- {num} parameters ---".format(num = pytorch_total_params))
397
+ pytorch_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
398
+ print("--- {num} trainable parameters ---".format(num = pytorch_trainable_params))
399
+ gpu_mem_usage_bytes = torch.cuda.max_memory_allocated()
400
+ print(gpu_mem_usage_bytes / 1024 / 1024 / 1024) # 64: 0.37 128: 0.84 -> 256: 3.02 -> 512: 12.55
401
+ #"""
402
+ """
403
+ import torch
404
+ from ptflops import get_model_complexity_info
405
+
406
+ with torch.cuda.device(0):
407
+ net = model
408
+ macs, params = get_model_complexity_info(net, (3, 512, 512), as_strings=True,
409
+ print_per_layer_stat=True, verbose=True)
410
+ print('{:<30} {:<8}'.format('Computational complexity: ', macs)) # 49.79 GMac
411
+ print('{:<30} {:<8}'.format('Number of parameters: ', params)) # 6.06 M
412
+ """
413
+ """
414
+ import time
415
+ start_time = time.time()
416
+ inp = torch.randn(1, 32, 512, 512).cuda().to(dtype=torch.float32)
417
+ model = Strip_VSSB(dim=32, head_num = 4).cuda().to(dtype=torch.float32)
418
+ out = model(inp)
419
+ print(out.shape)
420
+ print("--- %s seconds ---" % (time.time() - start_time))
421
+ pytorch_total_params = sum(p.numel() for p in model.parameters())
422
+ print("--- {num} parameters ---".format(num = pytorch_total_params))
423
+ pytorch_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
424
+ print("--- {num} trainable parameters ---".format(num = pytorch_trainable_params))
425
+ gpu_memmem_usage_bytes = torch.cuda.max_memory_allocated()
426
+ print(gpu_memmem_usage_bytes / 1024 / 1024 / 1024) # 128: 0.84 -> 256: 3.02 -> 512: 12.55
427
+ """
428
+
429
+
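An illustrative sketch of the strip reshaping performed by `Intra_SA` before attention: half of the channels are read row-by-row (each row becomes a sequence of W tokens) and the other half column-by-column (tensor sizes are arbitrary):

```python
import torch

B, C, H, W = 1, 192, 16, 24
x_h = torch.randn(B, C // 2, H, W)        # horizontal branch input
x_v = torch.randn(B, C // 2, H, W)        # vertical branch input

feat_h = x_h.permute(0, 2, 3, 1).reshape(B * H, W, C // 2)   # one sequence per row
feat_v = x_v.permute(0, 3, 2, 1).reshape(B * W, H, C // 2)   # one sequence per column
print(feat_h.shape, feat_v.shape)         # (16, 24, 96) and (24, 16, 96)
```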
models/sota/XYScanNet.py ADDED
@@ -0,0 +1,754 @@
1
+ import numbers
2
+ import math
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch import Tensor
9
+
10
+ from einops import rearrange, repeat
11
+
12
+ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
13
+
14
+ try:
15
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
16
+ except ImportError:
17
+ causal_conv1d_fn, causal_conv1d_update = None, None
18
+
19
+ try:
20
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
21
+ except ImportError:
22
+ selective_state_update = None
23
+
24
+ try:
25
+ from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
26
+ except ImportError:
27
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
28
+
29
+
30
+ def to_3d(x):
31
+ return rearrange(x, 'b c h w -> b (h w) c')
32
+
33
+
34
+ def to_4d(x, h, w):
35
+ return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
36
+
37
+
38
+ class BiasFree_LayerNorm(nn.Module):
39
+ def __init__(self, normalized_shape):
40
+ super(BiasFree_LayerNorm, self).__init__()
41
+ if isinstance(normalized_shape, numbers.Integral):
42
+ normalized_shape = (normalized_shape,)
43
+ normalized_shape = torch.Size(normalized_shape)
44
+
45
+ assert len(normalized_shape) == 1
46
+
47
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
48
+ self.normalized_shape = normalized_shape
49
+
50
+ def forward(self, x):
51
+ sigma = x.var(-1, keepdim=True, unbiased=False)
52
+ return x / torch.sqrt(sigma + 1e-5) * self.weight
53
+
54
+
55
+ class WithBias_LayerNorm(nn.Module):
56
+ def __init__(self, normalized_shape):
57
+ super(WithBias_LayerNorm, self).__init__()
58
+ if isinstance(normalized_shape, numbers.Integral):
59
+ normalized_shape = (normalized_shape,)
60
+ normalized_shape = torch.Size(normalized_shape)
61
+
62
+ assert len(normalized_shape) == 1
63
+
64
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
65
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
66
+ self.normalized_shape = normalized_shape
67
+
68
+ def forward(self, x):
69
+ mu = x.mean(-1, keepdim=True)
70
+ sigma = x.var(-1, keepdim=True, unbiased=False)
71
+ return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
72
+
73
+
74
+ class LayerNorm(nn.Module):
75
+ def __init__(self, dim, LayerNorm_type):
76
+ super(LayerNorm, self).__init__()
77
+ if LayerNorm_type == 'BiasFree':
78
+ self.body = BiasFree_LayerNorm(dim)
79
+ else:
80
+ self.body = WithBias_LayerNorm(dim)
81
+
82
+ def forward(self, x):
83
+ h, w = x.shape[-2:]
84
+ return to_4d(self.body(to_3d(x)), h, w)
85
+
86
+ ##########################################################################
87
+ def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1):
88
+ return nn.Conv2d(
89
+ in_channels, out_channels, kernel_size,
90
+ padding=(kernel_size//2), bias=bias, stride = stride)
91
+
92
+
93
+ """
94
+ Borrow from "https://github.com/state-spaces/mamba.git"
95
+ @article{mamba,
96
+ title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
97
+ author={Gu, Albert and Dao, Tri},
98
+ journal={arXiv preprint arXiv:2312.00752},
99
+ year={2023}
100
+ }
101
+ """
102
+ class Mamba(nn.Module):
103
+ def __init__(
104
+ self,
105
+ d_model,
106
+ d_state=16,
107
+ d_conv=4,
108
+ expand=2,
109
+ dt_rank="auto",
110
+ dt_min=0.001,
111
+ dt_max=0.1,
112
+ dt_init="random",
113
+ dt_scale=1.0,
114
+ dt_init_floor=1e-4,
115
+ conv_bias=True,
116
+ bias=False,
117
+ use_fast_path=True, # Fused kernel options
118
+ layer_idx=None,
119
+ device=None,
120
+ dtype=None,
121
+ ):
122
+ factory_kwargs = {"device": device, "dtype": dtype}
123
+ super().__init__()
124
+ self.d_model = d_model
125
+ self.d_state = d_state
126
+ self.d_conv = d_conv
127
+ self.expand = expand
128
+ self.d_inner = int(self.expand * self.d_model)
129
+ self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
130
+ self.use_fast_path = use_fast_path
131
+ self.layer_idx = layer_idx
132
+
133
+ self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
134
+
135
+ self.conv1d = nn.Conv1d(
136
+ in_channels=self.d_inner,
137
+ out_channels=self.d_inner,
138
+ bias=conv_bias,
139
+ kernel_size=d_conv,
140
+ groups=self.d_inner,
141
+ padding=d_conv - 1,
142
+ **factory_kwargs,
143
+ )
144
+
145
+ self.activation = "silu"
146
+ self.act = nn.SiLU()
147
+
148
+ self.x_proj = nn.Linear(
149
+ self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
150
+ )
151
+ self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
152
+
153
+ # Initialize special dt projection to preserve variance at initialization
154
+ dt_init_std = self.dt_rank**-0.5 * dt_scale
155
+ if dt_init == "constant":
156
+ nn.init.constant_(self.dt_proj.weight, dt_init_std)
157
+ elif dt_init == "random":
158
+ nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
159
+ else:
160
+ raise NotImplementedError
161
+
162
+ # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
163
+ dt = torch.exp(
164
+ torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
165
+ + math.log(dt_min)
166
+ ).clamp(min=dt_init_floor)
167
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
168
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
169
+ with torch.no_grad():
170
+ self.dt_proj.bias.copy_(inv_dt)
171
+ # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
172
+ self.dt_proj.bias._no_reinit = True
173
+
174
+ # S4D real initialization
175
+ A = repeat(
176
+ torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
177
+ "n -> d n",
178
+ d=self.d_inner,
179
+ ).contiguous()
180
+ A_log = torch.log(A) # Keep A_log in fp32
181
+ self.A_log = nn.Parameter(A_log)
182
+ self.A_log._no_weight_decay = True
183
+
184
+ # D "skip" parameter
185
+ self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
186
+ self.D._no_weight_decay = True
187
+
188
+ self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
189
+
190
+ def forward(self, hidden_states, inference_params=None):
191
+ """
192
+ hidden_states: (B, L, D)
193
+ Returns: same shape as hidden_states
194
+ """
195
+ batch, seqlen, dim = hidden_states.shape
196
+
197
+ conv_state, ssm_state = None, None
198
+ if inference_params is not None:
199
+ conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
200
+ if inference_params.seqlen_offset > 0:
201
+ # The states are updated inplace
202
+ out, _, _ = self.step(hidden_states, conv_state, ssm_state)
203
+ return out
204
+
205
+ # We do matmul and transpose BLH -> HBL at the same time
206
+ xz = rearrange(
207
+ self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
208
+ "d (b l) -> b d l",
209
+ l=seqlen,
210
+ )
211
+ if self.in_proj.bias is not None:
212
+ xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
213
+
214
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
215
+ # In the backward pass we write dx and dz next to each other to avoid torch.cat
216
+ if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None: # Doesn't support outputting the states
217
+ out = mamba_inner_fn(
218
+ xz,
219
+ self.conv1d.weight,
220
+ self.conv1d.bias,
221
+ self.x_proj.weight,
222
+ self.dt_proj.weight,
223
+ self.out_proj.weight,
224
+ self.out_proj.bias,
225
+ A,
226
+ None, # input-dependent B
227
+ None, # input-dependent C
228
+ self.D.float(),
229
+ delta_bias=self.dt_proj.bias.float(),
230
+ delta_softplus=True,
231
+ )
232
+ else:
233
+ x, z = xz.chunk(2, dim=1)
234
+ # Compute short convolution
235
+ if conv_state is not None:
236
+ # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
237
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
238
+ conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
239
+ if causal_conv1d_fn is None:
240
+ x = self.act(self.conv1d(x)[..., :seqlen])
241
+ else:
242
+ assert self.activation in ["silu", "swish"]
243
+ x = causal_conv1d_fn(
244
+ x=x,
245
+ weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
246
+ bias=self.conv1d.bias,
247
+ activation=self.activation,
248
+ )
249
+
250
+ # We're careful here about the layout, to avoid extra transposes.
251
+ # We want dt to have d as the slowest moving dimension
252
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
253
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
254
+ dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
255
+ dt = self.dt_proj.weight @ dt.t()
256
+ dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
257
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
258
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
259
+ assert self.activation in ["silu", "swish"]
260
+ y = selective_scan_fn(
261
+ x,
262
+ dt,
263
+ A,
264
+ B,
265
+ C,
266
+ self.D.float(),
267
+ z=z,
268
+ delta_bias=self.dt_proj.bias.float(),
269
+ delta_softplus=True,
270
+ return_last_state=ssm_state is not None,
271
+ )
272
+ if ssm_state is not None:
273
+ y, last_state = y
274
+ ssm_state.copy_(last_state)
275
+ y = rearrange(y, "b d l -> b l d")
276
+ out = self.out_proj(y)
277
+ return out
278
+
279
+ def step(self, hidden_states, conv_state, ssm_state):
280
+ dtype = hidden_states.dtype
281
+ assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
282
+ xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
283
+ x, z = xz.chunk(2, dim=-1) # (B D)
284
+
285
+ # Conv step
286
+ if causal_conv1d_update is None:
287
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
288
+ conv_state[:, :, -1] = x
289
+ x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
290
+ if self.conv1d.bias is not None:
291
+ x = x + self.conv1d.bias
292
+ x = self.act(x).to(dtype=dtype)
293
+ else:
294
+ x = causal_conv1d_update(
295
+ x,
296
+ conv_state,
297
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
298
+ self.conv1d.bias,
299
+ self.activation,
300
+ )
301
+
302
+ x_db = self.x_proj(x) # (B dt_rank+2*d_state)
303
+ dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
304
+ # Don't add dt_bias here
305
+ dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
306
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
307
+
308
+ # SSM step
309
+ if selective_state_update is None:
310
+ # Discretize A and B
311
+ dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
312
+ dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
313
+ dB = torch.einsum("bd,bn->bdn", dt, B)
314
+ ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
315
+ y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
316
+ y = y + self.D.to(dtype) * x
317
+ y = y * self.act(z) # (B D)
318
+ else:
319
+ y = selective_state_update(
320
+ ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
321
+ )
322
+
323
+ out = self.out_proj(y)
324
+ return out.unsqueeze(1), conv_state, ssm_state
325
+
326
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
327
+ device = self.out_proj.weight.device
328
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
329
+ conv_state = torch.zeros(
330
+ batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
331
+ )
332
+ ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
333
+ # ssm_dtype = torch.float32
334
+ ssm_state = torch.zeros(
335
+ batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
336
+ )
337
+ return conv_state, ssm_state
338
+
339
+ def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
340
+ assert self.layer_idx is not None
341
+ if self.layer_idx not in inference_params.key_value_memory_dict:
342
+ batch_shape = (batch_size,)
343
+ conv_state = torch.zeros(
344
+ batch_size,
345
+ self.d_model * self.expand,
346
+ self.d_conv,
347
+ device=self.conv1d.weight.device,
348
+ dtype=self.conv1d.weight.dtype,
349
+ )
350
+ ssm_state = torch.zeros(
351
+ batch_size,
352
+ self.d_model * self.expand,
353
+ self.d_state,
354
+ device=self.dt_proj.weight.device,
355
+ dtype=self.dt_proj.weight.dtype,
356
+ # dtype=torch.float32,
357
+ )
358
+ inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
359
+ else:
360
+ conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
361
+ # TODO: What if batch size changes between generation, and we reuse the same states?
362
+ if initialize_states:
363
+ conv_state.zero_()
364
+ ssm_state.zero_()
365
+ return conv_state, ssm_state
366
+
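A minimal sketch of the single-token recurrence that the non-fused fallback in `Mamba.step` implements (the fused selective-scan kernels compute the same discretized state-space update over a whole sequence); all sizes are arbitrary:

```python
import torch
import torch.nn.functional as F

# h_t = exp(dt*A) * h_{t-1} + (dt*B) * x_t ;  y_t = C h_t + D x_t ;  gated by silu(z)
bsz, d_inner, d_state = 2, 8, 16
x, z = torch.randn(bsz, d_inner), torch.randn(bsz, d_inner)
dt = F.softplus(torch.randn(bsz, d_inner))                 # per-channel step size
A = -torch.rand(d_inner, d_state)                          # negative real spectrum
B_t, C_t = torch.randn(bsz, d_state), torch.randn(bsz, d_state)
D = torch.ones(d_inner)
h = torch.zeros(bsz, d_inner, d_state)                     # SSM hidden state

dA = torch.exp(torch.einsum('bd,dn->bdn', dt, A))          # discretized A
dB = torch.einsum('bd,bn->bdn', dt, B_t)                   # discretized B
h = h * dA + x.unsqueeze(-1) * dB                          # state update
y = torch.einsum('bdn,bn->bd', h, C_t) + D * x             # readout + skip
y = y * F.silu(z)                                          # gating branch
print(y.shape)                                             # torch.Size([2, 8])
```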
367
+ ##########################################################################
368
+ ## Feed-forward Network
369
+ class FFN(nn.Module):
370
+ def __init__(self, dim, ffn_expansion_factor, bias):
371
+ super(FFN, self).__init__()
372
+
373
+ hidden_features = int(dim*ffn_expansion_factor)
374
+
375
+ self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
376
+
377
+ self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias, dilation=1)
378
+
379
+ self.win_size = 8
380
+
381
+ self.modulator = nn.Parameter(torch.ones(self.win_size, self.win_size, dim*2)) # modulator
382
+
383
+ self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
384
+
385
+ def forward(self, x):
386
+ b, c, h, w = x.shape
387
+ h1, w1 = h//self.win_size, w//self.win_size
388
+ x = self.project_in(x)
389
+ x = self.dwconv(x)
390
+ x_win = rearrange(x, 'b c (wsh h1) (wsw w1) -> b h1 w1 wsh wsw c', wsh=self.win_size, wsw=self.win_size)
391
+ x_win = x_win * self.modulator
392
+ x = rearrange(x_win, 'b h1 w1 wsh wsw c -> b c (wsh h1) (wsw w1)', wsh=self.win_size, wsw=self.win_size, h1=h1, w1=w1)
393
+ x1, x2 = x.chunk(2, dim=1)
394
+ x = x1 * x2
395
+ x = self.project_out(x)
396
+ return x
397
+
398
+
399
+ ##########################################################################
400
+ ## Gated Depth-wise Feed-forward Network (GDFN)
401
+ class GDFN(nn.Module):
402
+ def __init__(self, dim, ffn_expansion_factor, bias):
403
+ super(GDFN, self).__init__()
404
+
405
+ hidden_features = int(dim*ffn_expansion_factor)
406
+
407
+ self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
408
+
409
+ self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias, dilation=1)
410
+
411
+ self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
412
+
413
+ def forward(self, x):
414
+ x = self.project_in(x)
415
+ x = self.dwconv(x)
416
+ x1, x2 = x.chunk(2, dim=1)
417
+ x = F.silu(x1) * x2
418
+ x = self.project_out(x)
419
+ return x
420
+
421
+
422
+ ##########################################################################
423
+ ## Overlapped image patch embedding with 3x3 Conv
424
+ class OverlapPatchEmbed(nn.Module):
425
+ def __init__(self, in_c=3, embed_dim=48, bias=False):
426
+ super(OverlapPatchEmbed, self).__init__()
427
+
428
+ self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)
429
+
430
+ def forward(self, x):
431
+ x = self.proj(x)
432
+
433
+ return x
434
+
435
+
436
+ ##########################################################################
437
+ ## Resizing modules
438
+ class Downsample(nn.Module):
439
+ def __init__(self, n_feat):
440
+ super(Downsample, self).__init__()
441
+
442
+ self.body = nn.Sequential(nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
443
+ nn.Conv2d(n_feat, n_feat * 2, 3, stride=1, padding=1, bias=False))
444
+
445
+ def forward(self, x):
446
+ return self.body(x)
447
+
448
+
449
+ class Upsample(nn.Module):
450
+ def __init__(self, n_feat):
451
+ super(Upsample, self).__init__()
452
+
453
+ self.body = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
454
+ nn.Conv2d(n_feat, n_feat // 2, 3, stride=1, padding=1, bias=False))
455
+
456
+ def forward(self, x):
457
+ return self.body(x)
458
+
459
+
460
+ """
461
+ Borrowed from "https://github.com/pp00704831/Stripformer-ECCV-2022-.git"
462
+ @inproceedings{Tsai2022Stripformer,
463
+ author = {Fu-Jen Tsai and Yan-Tsung Peng and Yen-Yu Lin and Chung-Chi Tsai and Chia-Wen Lin},
464
+ title = {Stripformer: Strip Transformer for Fast Image Deblurring},
465
+ booktitle = {ECCV},
466
+ year = {2022}
467
+ }
468
+ """
469
+ class Intra_VSSM(nn.Module):
470
+ def __init__(self, dim, vssm_expansion_factor, bias): # gated = True
471
+ super(Intra_VSSM, self).__init__()
472
+ hidden = int(dim*vssm_expansion_factor)
473
+
474
+ self.proj_in = nn.Conv2d(dim, hidden*2, kernel_size=1, bias=bias)
475
+ self.dwconv = nn.Conv2d(hidden*2, hidden*2, kernel_size=3, stride=1, padding=1, groups=hidden*2, bias=bias)
476
+ self.proj_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)
477
+
478
+ self.conv_input = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
479
+ self.fuse_out = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
480
+ self.mamba = Mamba(d_model=hidden // 2)
481
+
482
+ def forward_core(self, x):
483
+ B, C, H, W = x.size()
484
+
485
+ x_input = torch.chunk(self.conv_input(x), 2, dim=1)
486
+
487
+ feature_h = (x_input[0]).permute(0, 2, 3, 1).contiguous()
488
+ feature_h = feature_h.view(B * H, W, C//2)
489
+
490
+ feature_v = (x_input[1]).permute(0, 3, 2, 1).contiguous()
491
+ feature_v = feature_v.view(B * W, H, C//2)
492
+
493
+ if H == W:
494
+ feature = torch.cat((feature_h, feature_v), dim=0) # B * H * 2, W, C//2
495
+ scan_output = self.mamba(feature)
496
+ scan_output = torch.chunk(scan_output, 2, dim=0)
497
+ scan_output_h = scan_output[0]
498
+ scan_output_v = scan_output[1]
499
+ else:
500
+ scan_output_h = self.mamba(feature_h)
501
+ scan_output_v = self.mamba(feature_v)
502
+
503
+ scan_output_h = scan_output_h.view(B, H, W, C//2).permute(0, 3, 1, 2).contiguous()
504
+ scan_output_v = scan_output_v.view(B, W, H, C//2).permute(0, 3, 2, 1).contiguous()
505
+ scan_output = self.fuse_out(torch.cat((scan_output_h, scan_output_v), dim=1))
506
+
507
+ return scan_output
508
+
509
+ def forward(self, x):
510
+ x = self.proj_in(x)
511
+ x, x_ = self.dwconv(x).chunk(2, dim=1)
512
+ x = self.forward_core(x)
513
+ x = F.silu(x_) * x
514
+ x = self.proj_out(x)
515
+ return x
516
+
517
+
518
+ class Inter_VSSM(nn.Module):
519
+ def __init__(self, dim, vssm_expansion_factor, bias): # gated = True
520
+ super(Inter_VSSM, self).__init__()
521
+ hidden = int(dim*vssm_expansion_factor)
522
+
523
+ self.proj_in = nn.Conv2d(dim, hidden*2, kernel_size=1, bias=bias)
524
+ self.dwconv = nn.Conv2d(hidden*2, hidden*2, kernel_size=3, stride=1, padding=1, groups=hidden*2, bias=bias)
525
+ self.proj_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)
526
+
527
+ self.avg_pool = nn.AdaptiveAvgPool2d((None,1))
528
+ self.conv_input = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
529
+ self.fuse_out = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
530
+ self.mamba = Mamba(d_model=hidden // 2)
531
+ self.sigmoid = nn.Sigmoid()
532
+
533
+ def forward_core(self, x):
534
+ B, C, H, W = x.size()
535
+
536
+ x_input = torch.chunk(self.conv_input(x), 2, dim=1) # B, C, H, W
537
+
538
+ feature_h = x_input[0].permute(0, 2, 1, 3).contiguous() # B, H, C//2, W
539
+ feature_h_score = self.avg_pool(feature_h) # B, H, C//2, 1
540
+ feature_h_score = feature_h_score.view(B, H, -1)
541
+
542
+ feature_v = x_input[1].permute(0, 3, 1, 2).contiguous() # B, W, C//2, H
543
+ feature_v_score = self.avg_pool(feature_v) # B, W, C//2, 1
544
+ feature_v_score = feature_v_score.view(B, W, -1)
545
+
546
+ if H == W:
547
+ feature_score = torch.cat((feature_h_score, feature_v_score), dim=0) # B * 2, W or H, C//2
548
+ scan_score = self.mamba(feature_score)
549
+ scan_score = torch.chunk(scan_score, 2, dim=0)
550
+ scan_score_h = scan_score[0]
551
+ scan_score_v = scan_score[1]
552
+ else:
553
+ scan_score_h = self.mamba(feature_h_score)
554
+ scan_score_v = self.mamba(feature_v_score)
555
+
556
+ scan_score_h = self.sigmoid(scan_score_h)
557
+ scan_score_v = self.sigmoid(scan_score_v)
558
+ feature_h = feature_h*scan_score_h[:,:,:,None]
559
+ feature_v = feature_v*scan_score_v[:,:,:,None]
560
+ feature_h = feature_h.view(B, H, C//2, W).permute(0, 2, 1, 3).contiguous()
561
+ feature_v = feature_v.view(B, W, C//2, H).permute(0, 2, 3, 1).contiguous()
562
+ output = self.fuse_out(torch.cat((feature_h, feature_v), dim=1))
563
+
564
+ return output
565
+
566
+ def forward(self, x):
567
+ x = self.proj_in(x)
568
+ x, x_ = self.dwconv(x).chunk(2, dim=1)
569
+ x = self.forward_core(x)
570
+ x = F.silu(x_) * x
571
+ x = self.proj_out(x)
572
+ return x
573
+
574
+
575
+ ##########################################################################
576
+ class Strip_VSSB(nn.Module):
577
+ def __init__(self, dim, vssm_expansion_factor, ffn_expansion_factor, bias=False, ssm=False, LayerNorm_type='WithBias'):
578
+ super(Strip_VSSB, self).__init__()
579
+ self.ssm = ssm
580
+ if self.ssm == True:
581
+ self.norm1_ssm = LayerNorm(dim, LayerNorm_type)
582
+ self.norm2_ssm = LayerNorm(dim, LayerNorm_type)
583
+ self.intra = Intra_VSSM(dim, vssm_expansion_factor, bias)
584
+ self.inter = Inter_VSSM(dim, vssm_expansion_factor, bias)
585
+ self.norm1_ffn = LayerNorm(dim, LayerNorm_type)
586
+ self.norm2_ffn = LayerNorm(dim, LayerNorm_type)
587
+ self.ffn1 = GDFN(dim, ffn_expansion_factor, bias)
588
+ self.ffn2 = GDFN(dim, ffn_expansion_factor, bias)
589
+
590
+ def forward(self, x):
591
+ if self.ssm == True:
592
+ x = x + self.intra(self.norm1_ssm(x))
593
+ x = x + self.ffn1(self.norm1_ffn(x))
594
+ if self.ssm == True:
595
+ x = x + self.inter(self.norm2_ssm(x))
596
+ x = x + self.ffn2(self.norm2_ffn(x))
597
+
598
+ return x
599
+
600
+
601
+ ##########################################################################
602
+ ##---------- Cross-level Feature Fusion by Adding Sigmoid(KL-Div) * Multi-Scale Feat -----------------------
603
+ class CLFF(nn.Module):
604
+ def __init__(self, dim, dim_n1, dim_n2, bias=False):
605
+ super(CLFF, self).__init__()
606
+
607
+ self.conv = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
608
+ self.conv_n1 = nn.Conv2d(dim_n1, dim, kernel_size=1, bias=bias)
609
+ self.conv_n2 = nn.Conv2d(dim_n2, dim, kernel_size=1, bias=bias)
610
+ self.fuse_out1 = nn.Conv2d(dim*2, dim, kernel_size=1, bias=bias)
611
+
612
+ self.log_sigmoid = nn.LogSigmoid()
613
+ self.sigmoid = nn.Sigmoid()
614
+
615
+ def forward(self, x, n1, n2):
616
+ x_ = self.conv(x)
617
+ n1_ = self.conv_n1(n1)
618
+ n2_ = self.conv_n2(n2)
619
+ kl_n1 = F.kl_div(input=self.log_sigmoid(n1_), target=self.log_sigmoid(x_), log_target=True)
620
+ kl_n2 = F.kl_div(input=self.log_sigmoid(n2_), target=self.log_sigmoid(x_), log_target=True)
621
+ #g = self.sigmoid(x_)
622
+ g1 = self.sigmoid(kl_n1)
623
+ g2 = self.sigmoid(kl_n2)
624
+ #x = (1 + g) * x_ + (1 - g) * (g1 * n1_ + g2 * n2_)
625
+ x = self.fuse_out1(torch.cat((x_, g1 * n1_ + g2 * n2_), dim=1))
626
+
627
+ return x
628
+
629
+ ##########################################################################
630
+ ##---------- XYScanNet -----------------------
631
+ class XYScanNet(nn.Module):
632
+ def __init__(self,
633
+ inp_channels=3,
634
+ out_channels=3,
635
+ dim = 24, # 48, 72, 96, 120, 144, default: 72
636
+ num_blocks = [3,3,6],
637
+ vssm_expansion_factor = 1, # 1 or 2
638
+ ffn_expansion_factor = 1, # 1 or 3
639
+ bias = False,
640
+ LayerNorm_type = 'WithBias', ## Other option 'BiasFree'
641
+ ):
642
+
643
+ super(XYScanNet, self).__init__()
644
+
645
+ self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
646
+
647
+ self.encoder_level1 = nn.Sequential(*[Strip_VSSB(dim=dim, vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
648
+ bias=bias, ssm=False, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
649
+
650
+ self.down1_2 = Downsample(dim) ## From Level 1 to Level 2
651
+ self.encoder_level2 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**1), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
652
+ bias=bias, ssm=False, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
653
+
654
+ self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3
655
+ self.encoder_level3 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**2), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
656
+ bias=bias, ssm=False, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
657
+
658
+ self.decoder_level3 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**2), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
659
+ bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
660
+
661
+ self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2
662
+ self.clff_level2 = CLFF(int(dim*2**1), dim_n1=int(dim*2**0), dim_n2=(dim*2**2), bias=bias)
663
+ self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
664
+ self.decoder_level2 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**1), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
665
+ bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
666
+
667
+ self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1
668
+ self.clff_level1 = CLFF(int(dim*2**0), dim_n1=int(dim*2**1), dim_n2=(dim*2**2), bias=bias)
669
+ self.reduce_chan_level1 = nn.Conv2d(int(dim*2**1), int(dim*2**0), kernel_size=1, bias=bias)
670
+ self.decoder_level1 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**0), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
671
+ bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
672
+
673
+ # self.refinement = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**0), expansion_factor=expansion_factor, bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])
674
+
675
+ self.output = nn.Conv2d(int(dim*2**0), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
676
+
677
+ def forward(self, inp_img):
678
+
679
+ # Encoder
680
+ inp_enc_level1 = self.patch_embed(inp_img)
681
+ out_enc_level1 = self.encoder_level1(inp_enc_level1)
682
+ out_enc_level1_2 = F.interpolate(out_enc_level1, scale_factor=0.5) # dim channels, lvl1 down-scaled to lvl2
683
+
684
+ inp_enc_level2 = self.down1_2(out_enc_level1)
685
+ out_enc_level2 = self.encoder_level2(inp_enc_level2)
686
+ out_enc_level2_1 = F.interpolate(out_enc_level2, scale_factor=2) # dim*2, lvl2 up-scaled to lvl1
687
+
688
+ inp_enc_level3 = self.down2_3(out_enc_level2)
689
+ out_enc_level3 = self.encoder_level3(inp_enc_level3)
690
+ out_enc_level3_2 = F.interpolate(out_enc_level3, scale_factor=2) # dim*2**2, lvl3 up-scaled to lvl2 (lvl3->lvl2)
691
+ out_enc_level3_1 = F.interpolate(out_enc_level3_2, scale_factor=2) # dim*2**2, lvl3 up-scaled to lvl1 (lvl3->lvl2->lvl1)
692
+
693
+ out_enc_level1 = self.clff_level1(out_enc_level1, out_enc_level2_1, out_enc_level3_1)
694
+ out_enc_level2 = self.clff_level2(out_enc_level2, out_enc_level1_2, out_enc_level3_2)
695
+
696
+ # Decoder
697
+ out_dec_level3_decomp1 = self.decoder_level3(out_enc_level3)
698
+
699
+ inp_dec_level2_decomp1 = self.up3_2(out_dec_level3_decomp1)
700
+ inp_dec_level2_decomp1 = self.reduce_chan_level2(torch.cat((inp_dec_level2_decomp1, out_enc_level2), dim=1))
701
+ out_dec_level2_decomp1 = self.decoder_level2(inp_dec_level2_decomp1)
702
+
703
+ inp_dec_level1_decomp1 = self.up2_1(out_dec_level2_decomp1)
704
+ inp_dec_level1_decomp1 = self.reduce_chan_level1(torch.cat((inp_dec_level1_decomp1, out_enc_level1), dim=1))
705
+ out_dec_level1_decomp1 = self.decoder_level1(inp_dec_level1_decomp1)
706
+
707
+ out_dec_level1_decomp1 = self.output(out_dec_level1_decomp1)
708
+
709
+ out_dec_level1 = out_dec_level1_decomp1 + inp_img
710
+
711
+
712
+ return out_dec_level1, out_dec_level1_decomp1, None
713
+
714
+ #"""
715
+ import time
716
+ start_time = time.time()
717
+ inp = torch.randn(1, 3, 512, 512).cuda()#.to(dtype=torch.float16)
718
+ model = XYScanNet().cuda()#.to(dtype=torch.float16)
719
+ out = model(inp)[0]
720
+ print(out.shape)
721
+ print("--- %s seconds ---" % (time.time() - start_time))
722
+ pytorch_total_params = sum(p.numel() for p in model.parameters())
723
+ print("--- {num} parameters ---".format(num = pytorch_total_params))
724
+ pytorch_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
725
+ print("--- {num} trainable parameters ---".format(num = pytorch_trainable_params))
726
+ gpu_mem_usage_bytes = torch.cuda.max_memory_allocated()
727
+ print(gpu_mem_usage_bytes / 1024 / 1024 / 1024) # 64: 0.61 128: 2.21 256: 8.56; 512: 33.45
728
+ #"""
729
+ """
730
+ import torch
731
+ from ptflops import get_model_complexity_info
732
+
733
+ with torch.cuda.device(0):
734
+ net = model
735
+ macs, params = get_model_complexity_info(net, (3, 256, 256), as_strings=True,
736
+ print_per_layer_stat=True, verbose=True)
737
+ print('{:<30} {:<8}'.format('Computational complexity: ', macs)) # 31.97 GMac
738
+ print('{:<30} {:<8}'.format('Number of parameters: ', params)) # 8.37 M
739
+ """
740
+ """
741
+ import time
742
+ start_time = time.time()
743
+ inp = torch.randn(1, 128, 64, 64).cuda()#.to(dtype=torch.float16)
744
+ model = Strip_VSSB(dim=128, expansion_factor=1).cuda()#.to(dtype=torch.float16)
745
+ out = model(inp)
746
+ print(out.shape)
747
+ print("--- %s seconds ---" % (time.time() - start_time))
748
+ pytorch_total_params = sum(p.numel() for p in model.parameters())
749
+ print("--- {num} parameters ---".format(num = pytorch_total_params))
750
+ pytorch_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
751
+ print("--- {num} trainable parameters ---".format(num = pytorch_trainable_params))
752
+ gpu_mem_usage_bytes = torch.cuda.max_memory_allocated()
753
+ print(gpu_mem_usage_bytes / 1024 / 1024 / 1024) # 128: 0.16 256: 0.24 512: 0.65
754
+ """
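
A quick note on using the model above: the encoder/decoder resolutions only line up when H and W are divisible by the downsampling factors, so the prediction scripts below reflect-pad the input to a multiple of 8 before the forward pass and crop afterwards. A minimal inference sketch, assuming a CUDA device, the mamba_ssm package, and that the shipped checkpoint results/xyscannetp_gopro/models/best_XYScanNet_stage2.pth loads directly into this module (the scripts route loading through get_generator instead):

import torch
import torch.nn.functional as F

model = XYScanNet().cuda().eval()
model.load_state_dict(torch.load('results/xyscannetp_gopro/models/best_XYScanNet_stage2.pth'))

blur = torch.rand(1, 3, 711, 1267).cuda() - 0.5         # inputs live in [-0.5, 0.5], as in the predict scripts
h, w = blur.shape[2], blur.shape[3]
pad_h, pad_w = (-h) % 8, (-w) % 8                       # pad H and W up to multiples of 8
blur = F.pad(blur, (0, pad_w, 0, pad_h), 'reflect')
with torch.no_grad():
    restored, residual, _ = model(blur)                 # forward returns (output, residual, None)
restored = (restored[:, :, :h, :w] + 0.5).clamp(0, 1)   # crop the padding, map back to [0, 1]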
out/Results.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Testing results are created in this folder.
predict_GoPro_test_results.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import numpy as np
3
+ import torch
4
+ import cv2
5
+ import yaml
6
+ import os
7
+ from torch.autograd import Variable
8
+ from models.networks import get_generator
9
+ import torchvision
10
+ import time
11
+ import argparse
12
+ import torch.nn.functional as F
13
+
14
+ def get_args():
15
+ parser = argparse.ArgumentParser(description='Test an image')
16
+ parser.add_argument('--job_name', default='xyscannet',
17
+ type=str, help="current job's name")
18
+ return parser.parse_args()
19
+
20
+ def print_max_gpu_usage():
21
+ """Prints the maximum GPU memory usage in GB."""
22
+ max_memory = torch.cuda.max_memory_allocated()
23
+ max_memory_in_gb = max_memory / (1024 ** 3) # Convert bytes to GB
24
+ print(f"Maximum GPU memory usage during test: {max_memory_in_gb:.2f} GB")
25
+
26
+ if __name__ == '__main__':
27
+ # optionally reset gpu
28
+ #torch.cuda.reset_max_memory_allocated()
29
+ args = get_args()
30
+ #with open(os.path.join('config/', args.job_name, 'config_stage2.yaml'), 'r') as cfg: # change the CFG name to test different models: pretrained, gopro, refined, stage1, stage2
31
+ # config = yaml.safe_load(cfg)
32
+ with open(os.path.join('config/', args.job_name, 'config_stage2.yaml'), 'r') as cfg: # change the CFG name to test different models: pretrained, gopro, refined, stage1, stage2
33
+ config = yaml.safe_load(cfg)
34
+ blur_path = '/mnt/g/RESEARCH/PHD/Motion_Deblurred/datasets/GOPRO_/test/testA'
35
+ out_path = os.path.join('results', args.job_name, 'images')
36
+ weights_path = os.path.join('results', args.job_name, 'models', 'best_{}.pth'.format(config['experiment_desc'])) # change the model name to test different phases: final/best
37
+ if not os.path.isdir(out_path):
38
+ os.mkdir(out_path)
39
+ model = get_generator(config['model'])
40
+ model.load_state_dict(torch.load(weights_path))
41
+ model = model.cuda()
42
+ #model.eval()
43
+
44
+ test_time = 0
45
+ iteration = 0
46
+ total_image_number = 1111
47
+
48
+ # warm-up
49
+ warm_up = 0
50
+ print('Hardware warm-up')
51
+ for file in os.listdir(blur_path):
52
+ for img_name in os.listdir(blur_path + '/' + file):
53
+ warm_up += 1
54
+ img = cv2.imread(blur_path + '/' + file + '/' + img_name)
55
+ img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
56
+ with torch.no_grad():
57
+ img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
58
+ result_image, decomp1, decomp2 = model(img_tensor)
59
+ #result_image = model(img_tensor)
60
+ if warm_up == 20:
61
+ break
62
+ break
63
+
64
+ for file in os.listdir(blur_path):
65
+ if not os.path.isdir(out_path + '/' + file):
66
+ os.mkdir(out_path + '/' + file)
67
+ for img_name in os.listdir(blur_path + '/' + file):
68
+ img = cv2.imread(blur_path + '/' + file + '/' + img_name)
69
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
70
+ img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
71
+ with torch.no_grad():
72
+ iteration += 1
73
+ img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
74
+
75
+ start = time.time()
76
+ result_image, decomp1, decomp2 = model(img_tensor)
77
+ #result_image = model(img_tensor)
78
+ stop = time.time()
79
+ print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start)))
80
+ test_time += stop - start
81
+ print('Average Runtime:{:.4f}'.format(test_time / float(iteration)))
82
+ result_image = result_image + 0.5
83
+ out_file_name = out_path + '/' + file + '/' + img_name
84
+ # optionally save image
85
+ torchvision.utils.save_image(result_image, out_file_name)
86
+
87
+ # optionally print gpu usage
88
+ #print_max_gpu_usage()
89
+ #torch.cuda.reset_max_memory_allocated()
predict_HIDE_test_results.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import numpy as np
3
+ import torch
4
+ import cv2
5
+ import yaml
6
+ import os
7
+ from torch.autograd import Variable
8
+ from models.networks import get_generator
9
+ import torchvision
10
+ import time
11
+ import argparse
12
+
13
+ def get_args():
14
+ parser = argparse.ArgumentParser(description='Test an image')
15
+ parser.add_argument('--job_name', default='xyscannet',
16
+ type=str, help="current job's name")
17
+ return parser.parse_args()
18
+
19
+ if __name__ == '__main__':
20
+ args = get_args()
21
+ with open(os.path.join('config/', args.job_name, 'config_stage2.yaml')) as cfg: # change the yaml file to config_pretrained if ablation
22
+ #with open(os.path.join('config/', args.job_name, 'config_stage2.yaml')) as cfg: # change the yaml file to config_pretrained if ablation
23
+ config = yaml.safe_load(cfg)
24
+ blur_path = '/scratch/user/hanzhou1996/datasets/deblur/HIDE/test/testA/'
25
+ out_path = os.path.join('results', args.job_name, 'images_hide')
26
+ weights_path = os.path.join('results', args.job_name, 'models', 'best_XYScanNet_stage2.pth') # change the model name to test different phases: final/best
27
+ if not os.path.isdir(out_path):
28
+ os.mkdir(out_path)
29
+ model = get_generator(config['model'])
30
+ model.load_state_dict(torch.load(weights_path))
31
+ model = model.cuda()
32
+ test_time = 0
33
+ iteration = 0
34
+ total_image_number = 2025
35
+
36
+ # warm up
37
+ warm_up = 0
38
+ print('Hardware warm-up')
39
+ for img_name in os.listdir(blur_path):
40
+ warm_up += 1
41
+ img = cv2.imread(blur_path + '/' + img_name)
42
+ img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
43
+ with torch.no_grad():
44
+ img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
45
+ result_image, decomp1, decomp2 = model(img_tensor)
46
+ #result_image = model(img_tensor)
47
+ if warm_up == 20:
48
+ break
49
+ break
50
+
51
+ for img_name in os.listdir(blur_path):
52
+ img = cv2.imread(blur_path + '/' + img_name)
53
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
54
+ img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
55
+ with torch.no_grad():
56
+ iteration += 1
57
+ img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
58
+
59
+ start = time.time()
60
+ result_image, decomp1, decomp2 = model(img_tensor)
61
+ #result_image = model(img_tensor)
62
+ stop = time.time()
63
+
64
+ print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start)))
65
+ test_time += stop - start
66
+ print('Average Runtime:{:.4f}'.format(test_time / float(iteration)))
67
+ result_image = result_image + 0.5
68
+ out_file_name = out_path + '/' + img_name
69
+ torchvision.utils.save_image(result_image, out_file_name)
predict_RWBI_test_results.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import numpy as np
3
+ import torch
4
+ import cv2
5
+ import yaml
6
+ import os
7
+ from torch.autograd import Variable
8
+ from models.networks import get_generator
9
+ import torchvision
10
+ import time
11
+ import torch.nn.functional as F
12
+ import argparse
13
+
14
+ def get_args():
15
+ parser = argparse.ArgumentParser(description='Test an image')
16
+ parser.add_argument('--job_name', default='xyscannet',
17
+ type=str, help="current job's name")
18
+ return parser.parse_args()
19
+
20
+ if __name__ == '__main__':
21
+ args = get_args()
22
+ with open(os.path.join('config/', args.job_name, 'config_pretrained.yaml')) as cfg: # change the yaml file to config_pretrained if ablation
23
+ #with open(os.path.join('config/', args.job_name, 'config_stage2.yaml')) as cfg: # change the yaml file to config_pretrained if ablation
24
+ config = yaml.safe_load(cfg)
25
+ blur_path = '/scratch/user/hanzhou1996/datasets/deblur/RWBI/test/testA/'
26
+ out_path = os.path.join('results', args.job_name, 'images_rwbi')
27
+ weights_path = os.path.join('results', args.job_name, 'models', 'best_XYScanNet_stage2.pth') # change the model name to test different phases: final/best
28
+ if not os.path.isdir(out_path):
29
+ os.mkdir(out_path)
30
+ model = get_generator(config['model'])
31
+ model.load_state_dict(torch.load(weights_path))
32
+ model = model.cuda()
33
+ test_time = 0
34
+ iteration = 0
35
+ total_image_number = 1000
36
+
37
+ # warm up
38
+ warm_up = 0
39
+ print('Hardware warm-up')
40
+ for img_name in os.listdir(blur_path):
41
+ warm_up += 1
42
+ img = cv2.imread(blur_path + '/' + img_name)
43
+ img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
44
+ with torch.no_grad():
45
+ img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
46
+ factor = 8
47
+ h, w = img_tensor.shape[2], img_tensor.shape[3]
48
+ H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
49
+ padh = H - h if h % factor != 0 else 0
50
+ padw = W - w if w % factor != 0 else 0
51
+ img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect')
52
+ H, W = img_tensor.shape[2], img_tensor.shape[3]
53
+
54
+ result_image, decomp1, decomp2 = model(img_tensor)
55
+ if warm_up == 20:
56
+ break
57
+
58
+ for file in os.listdir(blur_path):
59
+ if not os.path.isdir(out_path):
60
+ os.mkdir(out_path)
61
+ img = cv2.imread(blur_path + '/' + file)
62
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
63
+ img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
64
+ with torch.no_grad():
65
+ iteration += 1
66
+ img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
67
+
68
+ factor = 8
69
+ h, w = img_tensor.shape[2], img_tensor.shape[3]
70
+ H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
71
+ padh = H - h if h % factor != 0 else 0
72
+ padw = W - w if w % factor != 0 else 0
73
+ img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect')
74
+ H, W = img_tensor.shape[2], img_tensor.shape[3]
75
+
76
+ #with torch.autocast(device_type='cuda', dtype=torch.float16):
77
+ start = time.time()
78
+ result_image, decomp1, decomp2 = model(img_tensor)
79
+ stop = time.time()
80
+
81
+ result_image = result_image[:, :, :h, :w]
82
+
83
+ print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start)))
84
+ test_time += stop - start
85
+ print('Average Runtime:{:.4f}'.format(test_time / float(iteration)))
86
+ result_image = result_image + 0.5
87
+ out_file_name = out_path + '/' + file
88
+ torchvision.utils.save_image(result_image, out_file_name)
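
The RWBI and RealBlur scripts repeat the same bottom/right reflect padding so that H and W become multiples of 8 before inference, then crop the output back to the original size; for example h = 665 gives 665 % 8 = 1, so padh = 7 and the padded height is 672. A small helper capturing that logic could look like the sketch below (pad_to_multiple is illustrative, not part of this repository):

import torch.nn.functional as F

def pad_to_multiple(x, factor=8):
    """Reflect-pad an NCHW tensor on the bottom/right so H and W are multiples of `factor`.
    Returns the padded tensor and the original (h, w) for cropping the model output."""
    h, w = x.shape[2], x.shape[3]
    pad_h = (factor - h % factor) % factor
    pad_w = (factor - w % factor) % factor
    return F.pad(x, (0, pad_w, 0, pad_h), 'reflect'), (h, w)

# padded, (h, w) = pad_to_multiple(img_tensor)
# result_image, decomp1, decomp2 = model(padded)
# result_image = result_image[:, :, :h, :w]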
predict_RealBlur_J_test_results.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import numpy as np
3
+ import torch
4
+ import cv2
5
+ import yaml
6
+ import os
7
+ from torch.autograd import Variable
8
+ from models.networks import get_generator
9
+ import torchvision
10
+ import time
11
+ import torch.nn.functional as F
12
+ import argparse
13
+
14
+ def get_args():
15
+ parser = argparse.ArgumentParser(description='Test an image')
16
+ parser.add_argument('--job_name', default='xyscannet',
17
+ type=str, help="current job's name")
18
+ return parser.parse_args()
19
+
20
+ if __name__ == '__main__':
21
+ args = get_args()
22
+ with open(os.path.join('config/', args.job_name, 'config_stage2.yaml')) as cfg:
23
+ #with open(os.path.join('config/', args.job_name, 'config_pretrained.yaml')) as cfg:
24
+ config = yaml.safe_load(cfg)
25
+ blur_path = '/scratch/user/hanzhou1996/datasets/deblur/RealBlur_J/test/testA'
26
+ out_path = os.path.join('results', args.job_name, 'images_realj')
27
+ weights_path = os.path.join('results', args.job_name, 'models', 'final_XYScanNet_stage2.pth') # change the model name to test different phases: final/best final_StripMamba_pretrained.pth
28
+ if not os.path.isdir(out_path):
29
+ os.mkdir(out_path)
30
+ model = get_generator(config['model'])
31
+ model.load_state_dict(torch.load(weights_path))
32
+ model = model.cuda()
33
+ test_time = 0
34
+ iteration = 0
35
+ total_image_number = 980
36
+
37
+ # warm up
38
+ warm_up = 0
39
+ print('Hardware warm-up')
40
+ for file in os.listdir(blur_path):
41
+ #if not os.path.isdir(out_path + '/' + file):
42
+ # os.mkdir(out_path + '/' + file)
43
+ img_name = file
44
+ # for img_name in os.listdir(blur_path + '/' + file):
45
+ warm_up += 1
46
+ img = cv2.imread(blur_path + '/' + file)
47
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
48
+ img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
49
+ with torch.no_grad():
50
+ img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
51
+ factor = 8
52
+ h, w = img_tensor.shape[2], img_tensor.shape[3]
53
+ H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
54
+ padh = H - h if h % factor != 0 else 0
55
+ padw = W - w if w % factor != 0 else 0
56
+ img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect')
57
+ result_image, decomp1, decomp2 = model(img_tensor)
58
+ #result_image = model(img_tensor)
59
+ if warm_up == 20:
60
+ break
61
+ break
62
+
63
+ for file in os.listdir(blur_path):
64
+ #if not os.path.isdir(out_path + '/' + file):
65
+ # os.mkdir(out_path + '/' + file)
66
+ img_name = file
67
+ # for img_name in os.listdir(blur_path + '/' + file):
68
+ img = cv2.imread(blur_path + '/' + file)
69
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
70
+ img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
71
+ with torch.no_grad():
72
+ iteration += 1
73
+ img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
74
+
75
+ factor = 8
76
+ h, w = img_tensor.shape[2], img_tensor.shape[3]
77
+ H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
78
+ padh = H - h if h % factor != 0 else 0
79
+ padw = W - w if w % factor != 0 else 0
80
+ img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect')
81
+ H, W = img_tensor.shape[2], img_tensor.shape[3]
82
+
83
+ start = time.time()
84
+ _output, decomp1, decomp2 = model(img_tensor)
85
+ #_output = model(img_tensor)
86
+ stop = time.time()
87
+
88
+ result_image = _output[:, :, :h, :w]
89
+ result_image = torch.clamp(result_image, -0.5, 0.5)
90
+ result_image = result_image + 0.5
91
+
92
+ test_time += stop - start
93
+ print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start)))
94
+ print('Average Runtime:{:.4f}'.format(test_time / float(iteration)))
95
+ out_file_name = out_path + '/' + img_name
96
+ torchvision.utils.save_image(result_image, out_file_name)
97
+
predict_RealBlur_R_test_results.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import argparse
3
+ import numpy as np
4
+ import torch
5
+ import cv2
6
+ import yaml
7
+ import os
8
+ from torch.autograd import Variable
9
+ from models.networks import get_generator
10
+ import torchvision
11
+ import time
12
+ import torch.nn.functional as F
13
+
14
+ def get_args():
15
+ parser = argparse.ArgumentParser(description='Test an image')
16
+ parser.add_argument('--job_name', default='xyscannet',
17
+ type=str, help="current job's name")
18
+ return parser.parse_args()
19
+
20
+ if __name__ == '__main__':
21
+ args = get_args()
22
+ with open(os.path.join('config/', args.job_name, 'config_stage2.yaml')) as cfg:
23
+ config = yaml.safe_load(cfg)
24
+ blur_path = '/scratch/user/hanzhou1996/datasets/deblur/RealBlur_R/test/testA'
25
+ out_path = os.path.join('results', args.job_name, 'images_realr')
26
+ weights_path = os.path.join('results', args.job_name, 'models', 'final_XYScanNet_stage2.pth') # change the model name to test different phases: final/best
27
+ if not os.path.isdir(out_path):
28
+ os.mkdir(out_path)
29
+ model = get_generator(config['model'])
30
+ model.load_state_dict(torch.load(weights_path))
31
+ model = model.cuda()
32
+ test_time = 0
33
+ iteration = 0
34
+ total_image_number = 980
35
+
36
+ # warm up
37
+ warm_up = 0
38
+ print('Hardware warm-up')
39
+ for file in os.listdir(blur_path):
40
+ #if not os.path.isdir(out_path + '/' + file):
41
+ # os.mkdir(out_path + '/' + file)
42
+ #for img_name in os.listdir(blur_path + '/' + file):
43
+ img_name = file
44
+ warm_up += 1
45
+ img = cv2.imread(blur_path + '/' + file)
46
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
47
+ img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
48
+ with torch.no_grad():
49
+ img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
50
+ factor = 8
51
+ h, w = img_tensor.shape[2], img_tensor.shape[3]
52
+ H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
53
+ padh = H - h if h % factor != 0 else 0
54
+ padw = W - w if w % factor != 0 else 0
55
+ img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect')
56
+ result_image, decomp1, decomp2 = model(img_tensor)
57
+ #result_image = model(img_tensor)
58
+ if warm_up == 20:
59
+ break
60
+ break
61
+
62
+ for file in os.listdir(blur_path):
63
+ #if not os.path.isdir(out_path + '/' + file):
64
+ # os.mkdir(out_path + '/' + file)
65
+ #for img_name in os.listdir(blur_path + '/' + file):
66
+ img_name = file
67
+ img = cv2.imread(blur_path + '/' + file)
68
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
69
+ img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
70
+ with torch.no_grad():
71
+ iteration += 1
72
+ img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
73
+
74
+ factor = 8
75
+ h, w = img_tensor.shape[2], img_tensor.shape[3]
76
+ H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
77
+ padh = H - h if h % factor != 0 else 0
78
+ padw = W - w if w % factor != 0 else 0
79
+ img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect')
80
+ H, W = img_tensor.shape[2], img_tensor.shape[3]
81
+
82
+ start = time.time()
83
+ _output, decomp1, decomp2 = model(img_tensor)
84
+ #_output = model(img_tensor)
85
+ stop = time.time()
86
+
87
+ result_image = _output[:, :, :h, :w]
88
+ result_image = torch.clamp(result_image, -0.5, 0.5)
89
+ result_image = result_image + 0.5
90
+
91
+ test_time += stop - start
92
+ print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start)))
93
+ print('Average Runtime:{:.4f}'.format(test_time / float(iteration)))
94
+ out_file_name = out_path + '/' + img_name
95
+ torchvision.utils.save_image(result_image, out_file_name)
96
+
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio==4.44.1
2
+ spaces
3
+ torch==2.1.2
4
+ torchvision==0.16.2
5
+ transformers==4.46.3
6
+ einops==0.8.1
7
+ PyYAML==6.0.2
8
+ opencv-python-headless==4.10.0.84
9
+ numpy==1.26.4
10
+ pillow==10.4.0
11
+ mamba-ssm==2.2.2
results/xyscannetp_gopro/models/best_XYScanNet_stage2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da56c2c8ccb0c7cfc86c81431e9b7c2681c2109ca056d56d9454ba4aeb6c07e0
3
+ size 254328477
schedulers.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ from torch.optim import lr_scheduler
4
+
5
+
6
+ class WarmRestart(lr_scheduler.CosineAnnealingLR):
7
+ """This class implements Stochastic Gradient Descent with Warm Restarts (SGDR): https://arxiv.org/abs/1608.03983.
8
+
9
+ Sets the learning rate of each parameter group using a cosine annealing schedule. When last_epoch=-1, the initial lr is used.
10
+ This scheduler does not support scheduler.step(epoch); please keep epoch=None.
11
+ """
12
+
13
+ def __init__(self, optimizer, T_max=30, T_mult=1, eta_min=0, last_epoch=-1):
14
+ """implements SGDR
15
+
16
+ Parameters:
17
+ ----------
18
+ T_max : int
19
+ Maximum number of epochs.
20
+ T_mult : int
21
+ Multiplicative factor of T_max.
22
+ eta_min : float
23
+ Minimum learning rate. Default: 0.
24
+ last_epoch : int
25
+ The index of last epoch. Default: -1.
26
+ """
27
+ self.T_mult = T_mult
28
+ super().__init__(optimizer, T_max, eta_min, last_epoch)
29
+
30
+ def get_lr(self):
31
+ if self.last_epoch == self.T_max:
32
+ self.last_epoch = 0
33
+ self.T_max *= self.T_mult
34
+ return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2 for
35
+ base_lr in self.base_lrs]
36
+
37
+
38
+ class LinearDecay(lr_scheduler._LRScheduler):
39
+ """This class implements LinearDecay
40
+
41
+ """
42
+
43
+ def __init__(self, optimizer, num_epochs, start_epoch=0, min_lr=0, last_epoch=-1):
44
+ """Implements LinearDecay: the learning rate decays linearly from base_lr to min_lr over num_epochs epochs, starting at start_epoch.
45
+
46
+ Parameters:
47
+ ----------
48
+
49
+ """
50
+ self.num_epochs = num_epochs
51
+ self.start_epoch = start_epoch
52
+ self.min_lr = min_lr
53
+ super().__init__(optimizer, last_epoch)
54
+
55
+ def get_lr(self):
56
+ if self.last_epoch < self.start_epoch:
57
+ return self.base_lrs
58
+ return [base_lr - ((base_lr - self.min_lr) / self.num_epochs) * (self.last_epoch - self.start_epoch) for
59
+ base_lr in self.base_lrs]
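
Both schedulers are stepped once per epoch with no argument, as the WarmRestart docstring requires. A minimal usage sketch with a plain SGD optimizer (the values are illustrative, not the training configuration used by the repository):

import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=1e-4)

scheduler = WarmRestart(optimizer, T_max=30, T_mult=1, eta_min=1e-7)                 # cosine annealing that restarts every T_max epochs
# scheduler = LinearDecay(optimizer, num_epochs=1000, start_epoch=500, min_lr=1e-7)  # flat lr, then linear decay to min_lr

for epoch in range(60):
    # ... train one epoch ...
    scheduler.step()  # keep epoch=None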
train_XYScanNet_stage1.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from functools import partial
3
+ import os
4
+ import cv2
5
+ import torch
6
+ import torch.optim as optim
7
+ import tqdm
8
+ import yaml
9
+ from joblib import cpu_count
10
+ from torch.utils.data import DataLoader
11
+ import random
12
+ from dataset import PairedDataset
13
+ from metric_counter import MetricCounter
14
+
15
+ from models.losses import get_loss
16
+ from models.models import get_model
17
+ from models.networks import get_nets
18
+ from util import util
19
+ import numpy as np
20
+ from torch.optim.lr_scheduler import CosineAnnealingLR
21
+ cv2.setNumThreads(0)
22
+
23
+ import argparse
24
+ parser = argparse.ArgumentParser(description='Image motion deblurring evaluation on GoPro/HIDE')
25
+ parser.add_argument('--job_name', default='xyscannet',
26
+ type=str, help="current job's name")
27
+ args = parser.parse_args()
28
+
29
+ class Trainer:
30
+ def __init__(self, config, train: DataLoader, val: DataLoader):
31
+ self.config = config
32
+ self.train_dataset = train
33
+ self.val_dataset = val
34
+ self.metric_counter = MetricCounter(config['experiment_desc'])
35
+
36
+ def train(self):
37
+ self._init_params()
38
+ start_epoch = 0
39
+ print("The current job is: ", args.job_name)
40
+ model_dir = os.path.join('results/', args.job_name, 'models')
41
+ util.mkdir(model_dir)
42
+ if os.path.exists(os.path.join(model_dir, 'last_XYScanNet_stage1.pth')):
43
+ print('resume learning')
44
+ training_state = (torch.load(os.path.join(model_dir, 'last_XYScanNet_stage1.pth')))
45
+ start_epoch = training_state['epoch'] + 1
46
+ new_weight = self.netG.state_dict()
47
+ new_weight.update(training_state['model_state'])
48
+ self.netG.load_state_dict(new_weight)
49
+ new_optimizer = self.optimizer_G.state_dict()
50
+ new_optimizer.update(training_state['optimizer_state'])
51
+ self.optimizer_G.load_state_dict(new_optimizer)
52
+ new_scheduler = self.scheduler_G.state_dict()
53
+ new_scheduler.update(training_state['scheduler_state'])
54
+ self.scheduler_G.load_state_dict(new_scheduler)
55
+
56
+
57
+ for epoch in range(start_epoch, config['num_epochs']):
58
+ self._run_epoch(epoch)
59
+ if epoch % 30 == 0 or epoch == (config['num_epochs']-1):
60
+ self._validate(epoch)
61
+ self.scheduler_G.step()
62
+
63
+ scheduler_state = self.scheduler_G.state_dict()
64
+ training_state = {'epoch': epoch, 'model_state': self.netG.state_dict(),
65
+ 'scheduler_state': scheduler_state, 'optimizer_state': self.optimizer_G.state_dict()}
66
+ if self.metric_counter.update_best_model():
67
+ torch.save(training_state['model_state'],
68
+ os.path.join(model_dir, 'best_{}.pth'.format(self.config['experiment_desc'])))
69
+
70
+ if epoch % 300 == 0:
71
+ torch.save(training_state,
72
+ os.path.join(model_dir, 'last_{}_{}.pth'.format(self.config['experiment_desc'], epoch)))
73
+
74
+ if epoch == (config['num_epochs']-1):
75
+ torch.save(training_state['model_state'],
76
+ os.path.join(model_dir, 'final_{}.pth'.format(self.config['experiment_desc'])))
77
+
78
+ torch.save(training_state,
79
+ os.path.join(model_dir, 'last_{}.pth'.format(self.config['experiment_desc'])))
80
+
81
+ logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s" % (
82
+ self.config['experiment_desc'], epoch, self.metric_counter.loss_message()))
83
+
84
+ def _run_epoch(self, epoch):
85
+ self.metric_counter.clear()
86
+ for param_group in self.optimizer_G.param_groups:
87
+ lr = param_group['lr']
88
+
89
+ epoch_size = config.get('train_batches_per_epoch') or len(self.train_dataset)
90
+ tq = tqdm.tqdm(self.train_dataset)
91
+ tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
92
+ i = 0
93
+ for data in tq:
94
+ inputs, targets = self.model.get_input(data)
95
+ outputs, decomp1, decomp2 = self.netG(inputs)
96
+ self.optimizer_G.zero_grad()
97
+ loss_G = self.criterionG(outputs, targets, inputs)
98
+ loss_G.backward()
99
+ self.optimizer_G.step()
100
+ self.metric_counter.add_losses(loss_G.item())
101
+ curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
102
+ self.metric_counter.add_metrics(curr_psnr, curr_ssim)
103
+ tq.set_postfix(loss=self.metric_counter.loss_message())
104
+ if not i:
105
+ self.metric_counter.add_image(img_for_vis, tag='train')
106
+ i += 1
107
+ if i > len(self.train_dataset):
108
+ break
109
+ tq.close()
110
+ self.metric_counter.write_to_tensorboard(epoch)
111
+
112
+ def _validate(self, epoch):
113
+ self.metric_counter.clear()
114
+ epoch_size = config.get('val_batches_per_epoch') or len(self.val_dataset)
115
+ tq = tqdm.tqdm(self.val_dataset)
116
+ tq.set_description('Validation')
117
+ i = 0
118
+ for data in tq:
119
+ with torch.no_grad():
120
+ inputs, targets = self.model.get_input(data)
121
+ outputs, decomp1, decomp2 = self.netG(inputs)
122
+ loss_G = self.criterionG(outputs, targets, inputs)
123
+ self.metric_counter.add_losses(loss_G.item())
124
+ curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
125
+ self.metric_counter.add_metrics(curr_psnr, curr_ssim)
126
+ if not i:
127
+ self.metric_counter.add_image(img_for_vis, tag='val')
128
+ i += 1
129
+ if i > len(self.val_dataset):
130
+ break
131
+ tq.close()
132
+ self.metric_counter.write_to_tensorboard(epoch, validation=True)
133
+
134
+
135
+ def _get_optim(self, params):
136
+ if self.config['optimizer']['name'] == 'adam':
137
+ optimizer = optim.Adam(params, lr=self.config['optimizer']['lr'])
138
+ elif self.config['optimizer']['name'] == 'adamw':
139
+ optimizer = optim.AdamW(params, lr=0.001, weight_decay=0.001, betas=(0.9,0.9))
140
+ else:
141
+ raise ValueError("Optimizer [%s] not recognized." % self.config['optimizer']['name'])
142
+ return optimizer
143
+
144
+ def _get_scheduler(self, optimizer):
145
+ if self.config['scheduler']['name'] == 'cosine':
146
+ scheduler = CosineAnnealingLR(optimizer, T_max=self.config['num_epochs'], eta_min=self.config['scheduler']['min_lr'])
147
+ else:
148
+ raise ValueError("Scheduler [%s] not recognized." % self.config['scheduler']['name'])
149
+ return scheduler
150
+
151
+ def _init_params(self):
152
+ self.criterionG = get_loss(self.config['model'])
153
+ self.netG = get_nets(self.config['model'])
154
+ self.netG.cuda()
155
+ self.model = get_model(self.config['model'])
156
+ self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters()))
157
+ self.scheduler_G = self._get_scheduler(self.optimizer_G)
158
+
159
+
160
+ if __name__ == '__main__':
161
+ with open(os.path.join('config/', args.job_name, 'config_stage1.yaml'), 'r') as f:
162
+ config = yaml.safe_load(f)
163
+
164
+ # setup
165
+ torch.backends.cudnn.enabled = True
166
+ torch.backends.cudnn.benchmark = True
167
+
168
+ # set random seed
169
+ seed = 666
170
+ torch.manual_seed(seed)
171
+ torch.cuda.manual_seed(seed)
172
+ random.seed(seed)
173
+ np.random.seed(seed)
174
+
175
+ batch_size = config.pop('batch_size')
176
+ get_dataloader = partial(DataLoader, batch_size=batch_size, num_workers=cpu_count(), shuffle=True, drop_last=False)
177
+
178
+ datasets = map(config.pop, ('train', 'val'))
179
+ datasets = map(PairedDataset.from_config, datasets)
180
+ train, val = map(get_dataloader, datasets)
181
+ trainer = Trainer(config, train=train, val=val)
182
+ trainer.train()
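
The stage-1 trainer reads its settings from config/<job_name>/config_stage1.yaml. A sketch of the fields it consumes, written as the Python dict that yaml.safe_load would return (the values are illustrative, not the shipped configuration):

config = {
    'experiment_desc': 'XYScanNet_stage1',            # used in checkpoint names such as last_XYScanNet_stage1.pth
    'num_epochs': 3000,
    'batch_size': 8,                                  # popped before the DataLoaders are built
    'train_batches_per_epoch': None,                  # optional; len(train loader) is used when absent
    'val_batches_per_epoch': None,
    'optimizer': {'name': 'adam', 'lr': 1e-4},        # 'adamw' is also accepted
    'scheduler': {'name': 'cosine', 'min_lr': 1e-7},
    'model': {},                                      # forwarded to get_nets / get_loss / get_model
    'train': {},                                      # dataset spec consumed by PairedDataset.from_config
    'val': {},
}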
train_XYScanNet_stage2.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from functools import partial
3
+ import os
4
+ import cv2
5
+ import torch
6
+ import torch.optim as optim
7
+ import tqdm
8
+ import yaml
9
+ from joblib import cpu_count
10
+ from torch.utils.data import DataLoader
11
+ import random
12
+ from dataset import PairedDataset
13
+ from metric_counter import MetricCounter
14
+ from models.losses import get_loss
15
+ from models.models import get_model
16
+ from models.networks import get_nets
17
+ import numpy as np
18
+ from torch.optim.lr_scheduler import CosineAnnealingLR
19
+ cv2.setNumThreads(0)
20
+
21
+ import argparse
22
+ parser = argparse.ArgumentParser(description='Image motion deblurring evaluation on GoPro/HIDE')
23
+ parser.add_argument('--job_name', default='xyscannet',
24
+ type=str, help="current job's name")
25
+ args = parser.parse_args()
26
+
27
+ class Trainer:
28
+ def __init__(self, config, train: DataLoader, val: DataLoader):
29
+ self.config = config
30
+ self.train_dataset = train
31
+ self.val_dataset = val
32
+ self.metric_counter = MetricCounter(config['experiment_desc'])
33
+
34
+ def train(self):
35
+ self._init_params()
36
+ start_epoch = 0
37
+ print("The current job is: ", args.job_name)
38
+ model_dir = os.path.join('results/', args.job_name, 'models')
39
+ if os.path.exists(os.path.join(model_dir, 'last_XYScanNet_stage2.pth')):
40
+ print('resume learning')
41
+ training_state = (torch.load(os.path.join(model_dir, 'last_XYScanNet_stage2.pth')))
42
+ start_epoch = training_state['epoch'] + 1
43
+ new_weight = self.netG.state_dict()
44
+ new_weight.update(training_state['model_state'])
45
+ self.netG.load_state_dict(new_weight)
46
+ new_optimizer = self.optimizer_G.state_dict()
47
+ new_optimizer.update(training_state['optimizer_state'])
48
+ self.optimizer_G.load_state_dict(new_optimizer)
49
+ new_scheduler = self.scheduler_G.state_dict()
50
+ new_scheduler.update(training_state['scheduler_state'])
51
+ self.scheduler_G.load_state_dict(new_scheduler)
52
+ else:
53
+ print('load_weights_stage1')
54
+ training_state = (torch.load(os.path.join(model_dir, 'final_XYScanNet_stage1.pth')))
55
+ new_weight = self.netG.state_dict()
56
+ new_weight.update(training_state)
57
+ self.netG.load_state_dict(new_weight)
58
+
59
+
60
+ for epoch in range(start_epoch, config['num_epochs']):
61
+ self._run_epoch(epoch)
62
+ if epoch % 30 == 0 or epoch == (config['num_epochs']-1):
63
+ self._validate(epoch)
64
+ self.scheduler_G.step()
65
+
66
+ scheduler_state = self.scheduler_G.state_dict()
67
+ training_state = {'epoch': epoch, 'model_state': self.netG.state_dict(),
68
+ 'scheduler_state': scheduler_state, 'optimizer_state': self.optimizer_G.state_dict()}
69
+ if self.metric_counter.update_best_model():
70
+ torch.save(training_state['model_state'],
71
+ os.path.join(model_dir, 'best_{}.pth'.format(self.config['experiment_desc'])))
72
+ if epoch % 200 == 0:
73
+ torch.save(training_state,
74
+ os.path.join(model_dir, 'last_{}_{}.pth'.format(self.config['experiment_desc'], epoch)))
75
+
76
+ if epoch == (config['num_epochs']-1):
77
+ torch.save(training_state['model_state'],
78
+ os.path.join(model_dir, 'final_{}.pth'.format(self.config['experiment_desc'])))
79
+
80
+ torch.save(training_state,
81
+ os.path.join(model_dir, 'last_{}.pth'.format(self.config['experiment_desc'])))
82
+ logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s" % (
83
+ self.config['experiment_desc'], epoch, self.metric_counter.loss_message()))
84
+
85
+ def _run_epoch(self, epoch):
86
+ self.metric_counter.clear()
87
+ for param_group in self.optimizer_G.param_groups:
88
+ lr = param_group['lr']
89
+
90
+ epoch_size = config.get('train_batches_per_epoch') or len(self.train_dataset)
91
+ tq = tqdm.tqdm(self.train_dataset)
92
+ tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
93
+ i = 0
94
+ for data in tq:
95
+ inputs, targets = self.model.get_input(data)
96
+ outputs, decomp1, decomp2 = self.netG(inputs)
97
+ #outputs = self.netG(inputs)
98
+ self.optimizer_G.zero_grad()
99
+ loss_G = self.criterionG(outputs, targets, inputs)
100
+ loss_G.backward()
101
+ self.optimizer_G.step()
102
+ self.metric_counter.add_losses(loss_G.item())
103
+ curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
104
+ self.metric_counter.add_metrics(curr_psnr, curr_ssim)
105
+ tq.set_postfix(loss=self.metric_counter.loss_message())
106
+ if not i:
107
+ self.metric_counter.add_image(img_for_vis, tag='train')
108
+ i += 1
109
+ if i > len(self.train_dataset):
110
+ break
111
+ tq.close()
112
+ self.metric_counter.write_to_tensorboard(epoch)
113
+
114
+ def _validate(self, epoch):
115
+ self.metric_counter.clear()
116
+ epoch_size = config.get('val_batches_per_epoch') or len(self.val_dataset)
117
+ tq = tqdm.tqdm(self.val_dataset)
118
+ tq.set_description('Validation')
119
+ i = 0
120
+ for data in tq:
121
+ with torch.no_grad():
122
+ inputs, targets = self.model.get_input(data)
123
+ outputs, decomp1, decomp2 = self.netG(inputs)
124
+ #outputs = self.netG(inputs)
125
+ loss_G = self.criterionG(outputs, targets, inputs)
126
+ self.metric_counter.add_losses(loss_G.item())
127
+ curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
128
+ self.metric_counter.add_metrics(curr_psnr, curr_ssim)
129
+ if not i:
130
+ self.metric_counter.add_image(img_for_vis, tag='val')
131
+ i += 1
132
+ if i > len(self.val_dataset):
133
+ break
134
+ tq.close()
135
+ self.metric_counter.write_to_tensorboard(epoch, validation=True)
136
+
137
+
138
+ def _get_optim(self, params):
139
+ if self.config['optimizer']['name'] == 'adam':
140
+ optimizer = optim.Adam(params, lr=self.config['optimizer']['lr'])
141
+ else:
142
+ raise ValueError("Optimizer [%s] not recognized." % self.config['optimizer']['name'])
143
+ return optimizer
144
+
145
+ def _get_scheduler(self, optimizer):
146
+ if self.config['scheduler']['name'] == 'cosine':
147
+ scheduler = CosineAnnealingLR(optimizer, T_max=self.config['num_epochs'], eta_min=self.config['scheduler']['min_lr'])
148
+ else:
149
+ raise ValueError("Scheduler [%s] not recognized." % self.config['scheduler']['name'])
150
+ return scheduler
151
+
152
+ def _init_params(self):
153
+ self.criterionG = get_loss(self.config['model'])
154
+ self.netG = get_nets(self.config['model'])
155
+ self.netG.cuda()
156
+ self.model = get_model(self.config['model'])
157
+ self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters()))
158
+ self.scheduler_G = self._get_scheduler(self.optimizer_G)
159
+
160
+
161
+ if __name__ == '__main__':
162
+ with open(os.path.join('config/', args.job_name, 'config_stage2.yaml'), 'r') as f:
163
+ config = yaml.safe_load(f)
164
+ # setup
165
+ torch.backends.cudnn.enabled = True
166
+ torch.backends.cudnn.benchmark = True
167
+
168
+ # set random seed
169
+ seed = 666
170
+ torch.manual_seed(seed)
171
+ torch.cuda.manual_seed(seed)
172
+ random.seed(seed)
173
+ np.random.seed(seed)
174
+
175
+ batch_size = config.pop('batch_size')
176
+ get_dataloader = partial(DataLoader, batch_size=batch_size, num_workers=cpu_count(), shuffle=True, drop_last=False)
177
+
178
+ datasets = map(config.pop, ('train', 'val'))
179
+ datasets = map(PairedDataset.from_config, datasets)
180
+ train, val = map(get_dataloader, datasets)
181
+ trainer = Trainer(config, train=train, val=val)
182
+ trainer.train()
util/__init__.py ADDED
File without changes
util/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (143 Bytes). View file
 
util/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (126 Bytes). View file
 
util/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (145 Bytes). View file