Low-Light Enhancement in the Frequency Domain
Authors: Hao Chen, Zhi Jin
Abstract: Decreased visibility, intense noise, and biased color are common problems in low-light images. These visual disturbances further reduce the performance of high-level vision tasks, such as object detection and tracking. To address this issue, some image enhancement methods have been proposed to increase the image contrast. However, most of them are implemented only in the spatial domain, which makes them susceptible to noise while enhancing. Hence, in this work, we propose a novel residual recurrent multi-wavelet convolutional neural network (R2-MWCNN) learned in the frequency domain that can simultaneously increase image contrast and suppress noise. This end-to-end trainable network utilizes a multi-level discrete wavelet transform to divide input feature maps into distinct frequencies, resulting in a stronger denoising effect. A channel-wise loss function is proposed to correct the color distortion for more realistic results. Extensive experiments demonstrate that our proposed R2-MWCNN outperforms the state-of-the-art methods quantitatively and qualitatively.
What, Why and How
Here is a summary of the key points from the paper:
What:
- This paper proposes a new deep learning method called Residual Recurrent Multi-Wavelet Convolutional Neural Network (R2-MWCNN) for low-light image enhancement.
Why:
- Existing methods that operate in the pixel domain have limited ability to remove noise while enhancing low-light images, often resulting in noise amplification and color distortion.
- Transforming to the frequency domain using wavelets allows better separation of noise from structural image information.
How:
- Uses the discrete wavelet transform (DWT) and inverse DWT to extract multi-scale frequency features and replace pooling/upsampling operations (a small Haar-DWT sketch follows this summary).
- Proposes Multi-level Shortcut Connections (MSC) module to adaptively transmit low-level features to aid detail recovery.
- Employs a channel-wise loss to handle color distortion.
- R2-MWCNN architecture integrates illumination enhancement, color restoration and denoising in an end-to-end model.
In summary, this paper presents a frequency-domain deep learning method for low-light image enhancement that leverages wavelets and a novel network architecture to effectively restore illumination while suppressing noise and color distortion. Experiments show it outperforms the state of the art both quantitatively and qualitatively.
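For intuition on how the DWT can stand in for pooling, below is a minimal TensorFlow sketch of a single-level 2D DWT and its inverse, assuming a Haar wavelet (illustrative only; the paper's exact wavelet and normalization may differ). The DWT splits a feature map into four half-resolution sub-bands (LL, LH, HL, HH), separating the mostly-noise high frequencies from the low-frequency structure, and the IDWT reconstructs the original resolution exactly.

import tensorflow as tf

def haar_dwt_2d(x):
    """Single-level 2D Haar DWT on an NHWC tensor (illustrative sketch).

    Returns the four half-resolution sub-bands concatenated along channels
    in the order LL, LH, HL, HH.
    """
    # Split the feature map into its four 2x2 polyphase components
    a = x[:, 0::2, 0::2, :]  # top-left samples
    b = x[:, 0::2, 1::2, :]  # top-right samples
    c = x[:, 1::2, 0::2, :]  # bottom-left samples
    d = x[:, 1::2, 1::2, :]  # bottom-right samples
    ll = (a + b + c + d) / 2.0  # low-low: smooth structure
    lh = (a - b + c - d) / 2.0  # horizontal detail
    hl = (a + b - c - d) / 2.0  # vertical detail
    hh = (a - b - c + d) / 2.0  # diagonal detail (most noise-like)
    return tf.concat([ll, lh, hl, hh], axis=-1)

def haar_idwt_2d(y):
    """Exact inverse of haar_dwt_2d (perfect reconstruction)."""
    ll, lh, hl, hh = tf.split(y, 4, axis=-1)
    a = (ll + lh + hl + hh) / 2.0
    b = (ll - lh + hl - hh) / 2.0
    c = (ll + lh - hl - hh) / 2.0
    d = (ll - lh - hl + hh) / 2.0
    # Interleave the polyphase components back to full resolution
    n, h, w, ch = tf.shape(a)[0], tf.shape(a)[1], tf.shape(a)[2], tf.shape(a)[3]
    top = tf.reshape(tf.stack([a, b], axis=3), [n, h, 2 * w, ch])
    bottom = tf.reshape(tf.stack([c, d], axis=3), [n, h, 2 * w, ch])
    return tf.reshape(tf.stack([top, bottom], axis=2), [n, 2 * h, 2 * w, ch])

Stacking the four sub-bands along the channel axis halves the spatial size and quadruples the channel count, which is why DWT/IDWT can replace pooling and upsampling in the contracting and expanding paths.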
Main Contributions
Here are the main contributions of this paper:
- Proposes a novel end-to-end frequency-based architecture (R2-MWCNN) that integrates illumination enhancement, color restoration, and denoising for low-light image enhancement.
- Introduces the discrete wavelet transform (DWT) and inverse DWT to exploit multi-scale frequency features, enabling better separation of noise from structural information compared to pixel-domain methods.
- Presents a Multi-level Shortcut Connections (MSC) module that adaptively selects and transmits low-level features to aid detail recovery in deeper layers of the network.
- Designs a channel-wise loss function to handle color distortion, a common issue in low-light enhancement.
- Achieves state-of-the-art performance on real low-light images, outperforming previous methods quantitatively and qualitatively.
In summary, the main contributions are:
- A new frequency-domain network architecture for end-to-end low-light enhancement
- Use of wavelets for multi-scale frequency analysis to separate noise
- MSC module for adaptive use of low-level features
- Channel-wise loss to handle color distortion
- Superior quantitative and qualitative performance over existing methods
Method Section
Here is a summary of the method section from the paper:
- Proposes a Residual Recurrent Multi-Wavelet Convolutional Neural Network (R2-MWCNN) architecture.
- Uses a U-Net structure with contracting and expanding subnetworks.
- In the contracting subnetwork, replaces pooling with the discrete wavelet transform (DWT) to move to the frequency domain and extract multi-scale features.
- In the expanding subnetwork, uses the inverse DWT (IDWT) to restore information and return to the pixel domain.
- Introduces a Multi-level Shortcut Connections (MSC) module that uses shortcut connections and add operations to transmit low-level features to deeper layers, aiding detail recovery.
- MSC uses an attention map to select relevant features and adaptively combines the input features with the outputs of the MSC modules.
- Uses a two-layer DWT to generate four sub-bands (LL, LH, HL, HH) representing different frequency components.
- Applies channel-wise attention to the DWT outputs.
- Defines a loss function with four components:
  - Pixel-wise loss using a smoothed L1
  - Perceptual and structural losses
  - Edge loss using the Sobel operator
  - Channel-wise loss to handle color distortion
In summary, the key aspects of the method are:
- R2-MWCNN architecture with a U-Net backbone, DWT/IDWT, and MSC modules
- MSC adaptively transmits low-level features using shortcut connections
- DWT analysis and IDWT synthesis for frequency-domain processing
- Channel-wise loss to correct color distortion (a hedged sketch follows below)
The architecture and loss function are designed to effectively restore illumination, suppress noise and retain details for enhanced low-light images.
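The summary above does not give the exact form of the channel-wise loss, so the following is only a hedged sketch of one plausible reading: it penalizes the gap between the per-channel mean intensities of the enhanced image and the ground truth, nudging the global color balance toward the reference. The function name and formulation are illustrative, not the paper's exact definition.

import tensorflow as tf

def channelwise_loss(enhanced_img, ground_truth_img):
    """Hypothetical channel-wise color loss (illustrative, not the paper's exact formula).

    Compares the mean intensity of each color channel so that the enhanced
    image's global color balance matches the ground truth.
    """
    # Per-channel spatial means, shape (batch, channels)
    mean_enh = tf.reduce_mean(enhanced_img, axis=[1, 2])
    mean_gt = tf.reduce_mean(ground_truth_img, axis=[1, 2])
    return tf.reduce_mean(tf.abs(mean_enh - mean_gt))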
High-Level Pseudo Code
Here is the high-level pseudo code for the key aspects of the method proposed in the paper:
# R2-MWCNN Architecture
def contracting_path(input_image):
    features = input_image
    for i in range(num_levels):
        features = MSC_module(features)   # see MSC pseudo code
        features = Conv_block(features)
        features = DWT(features)          # transform to the frequency domain
    return features

def expanding_path(input_features):
    features = input_features
    for i in range(num_levels):
        features = IDWT(features)         # inverse transform back to the pixel domain
        features = Conv_block(features)
        features = MSC_module(features)
    return features

def R2MWCNN(input_image):
    cont_features = contracting_path(input_image)
    exp_features = expanding_path(cont_features)
    output_image = Conv_block(exp_features)
    return output_image

# MSC Module
def MSC_module(input_features):
    attention_map = ComputeAttentionMap(input_features)
    attended_features = attention_map * input_features
    features = input_features
    for i in range(num_layers):
        residual = Conv1x1(attended_features)
        features = Conv_block(features + residual)
    return features

# Loss Function
def compute_loss(enhanced_image, ground_truth):
    pixel_loss = SmoothL1Loss(enhanced_image, ground_truth)
    perceptual_loss = VGGFeatureLoss(enhanced_image, ground_truth)
    structural_loss = 1 - SSIM(enhanced_image, ground_truth)
    edge_loss = SobelEdgeLoss(enhanced_image, ground_truth)
    channelwise_loss = ChannelLoss(enhanced_image, ground_truth)
    total_loss = (pixel_loss + perceptual_loss + structural_loss
                  + edge_loss + channelwise_loss)
    return total_loss
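SobelEdgeLoss above is only named, not defined. A minimal TensorFlow sketch of such an edge loss is given below, using tf.image.sobel_edges to compare gradient maps; this is an assumed implementation, not the paper's exact operator.

import tensorflow as tf

def sobel_edge_loss(enhanced_img, ground_truth_img):
    """Illustrative edge loss: L1 distance between Sobel gradient maps."""
    # tf.image.sobel_edges returns shape (batch, H, W, C, 2) holding dy and dx
    edges_enh = tf.image.sobel_edges(enhanced_img)
    edges_gt = tf.image.sobel_edges(ground_truth_img)
    return tf.reduce_mean(tf.abs(edges_enh - edges_gt))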
Detailed Pseudo Code
Here is a more detailed pseudo code to implement the key aspects of the paper:
# Imports and setup
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Conv2DTranspose
from tensorflow.keras.models import Model

# R2-MWCNN model (simplified to a single DWT/IDWT level for readability)
class R2MWCNN(Model):
    def __init__(self):
        super(R2MWCNN, self).__init__()
        # Contracting path
        self.conv1 = Conv2D(filters=32, kernel_size=3, padding='same', activation='relu')
        self.conv2 = Conv2D(filters=64, kernel_size=3, padding='same', activation='relu')
        self.dwt1 = DWTLayer()                 # wavelet transform layer (sketched after this code)
        self.msc1 = MSC_Module(filters=256)    # DWT quadruples the channel count: 64 -> 256
        # Expanding path
        self.idwt1 = IDWTLayer()               # inverse wavelet transform layer
        self.msc2 = MSC_Module(filters=64)     # back to 64 channels after IDWT
        self.tconv1 = Conv2DTranspose(filters=64, kernel_size=3, padding='same', activation='relu')
        self.tconv2 = Conv2DTranspose(filters=32, kernel_size=3, padding='same', activation='relu')
        # Output layer
        self.output_layer = Conv2D(filters=3, kernel_size=3, padding='same')

    def call(self, x):
        # Contracting path: spatial convs, then transform to the frequency domain
        c1 = self.conv1(x)
        c2 = self.conv2(c1)
        d1 = self.dwt1(c2)
        m1 = self.msc1(d1)
        # Expanding path: return to the pixel domain, then refine
        i1 = self.idwt1(m1)
        m2 = self.msc2(i1)
        t1 = self.tconv1(m2)
        t2 = self.tconv2(t1)
        # Generate the enhanced image
        return self.output_layer(t2)

# MSC Module
class MSC_Module(Model):
    def __init__(self, filters=32):
        super(MSC_Module, self).__init__()
        self.conv1 = Conv2D(filters=filters, kernel_size=3, padding='same', activation='relu')
        self.conv2 = Conv2D(filters=filters, kernel_size=3, padding='same', activation='relu')
        self.conv3 = Conv2D(filters=filters, kernel_size=3, padding='same', activation='relu')
        # 1x1 convs for the shortcut branches (created once, not on every call)
        self.res1 = Conv2D(filters=filters, kernel_size=1)
        self.res2 = Conv2D(filters=filters, kernel_size=1)
        self.res3 = Conv2D(filters=filters, kernel_size=1)
        self.attn = AttentionLayer()           # channel-attention layer (sketched after this code)

    def call(self, input_features):
        # Attention map selects which low-level features to pass on
        attn_map = self.attn(input_features)
        attn_features = input_features * attn_map   # broadcast channel weights over H, W
        # Multi-level shortcut connections
        x1 = input_features + self.res1(attn_features)
        c1 = self.conv1(x1)
        x2 = x1 + self.res2(c1)
        c2 = self.conv2(x2)
        x3 = x2 + self.res3(c2)
        c3 = self.conv3(x3)
        return c3

# Loss function
def r2mwcnn_loss(enhanced_img, ground_truth_img):
    # Pixel-wise loss (smoothed L1 / Huber)
    pixel_loss = tf.keras.losses.Huber()(ground_truth_img, enhanced_img)
    # Perceptual loss (VGG feature distance; vgg_loss sketched after this code)
    perceptual_loss = vgg_loss(enhanced_img, ground_truth_img)
    # Structural loss (1 - SSIM)
    structural_loss = 1.0 - tf.reduce_mean(tf.image.ssim(enhanced_img, ground_truth_img, max_val=1.0))
    # Edge loss (Sobel gradients; sobel_edge_loss sketched earlier)
    edge_loss = sobel_edge_loss(enhanced_img, ground_truth_img)
    # Channel-wise color loss (channelwise_loss sketched earlier)
    channel_loss = channelwise_loss(enhanced_img, ground_truth_img)
    # Weighted sum (weights as stated in this summary)
    total_loss = pixel_loss + 0.2*perceptual_loss + 5*structural_loss + 0.1*edge_loss + 2*channel_loss
    return total_loss
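DWTLayer, IDWTLayer, AttentionLayer, and vgg_loss are placeholders in the code above. Below is one hedged way to fill them in: the wavelet layers wrap the Haar DWT/IDWT functions sketched earlier, the attention layer is a squeeze-and-excitation style channel attention, and the perceptual loss compares pretrained VGG19 feature maps (assuming inputs in [0, 1]). These are sketches under stated assumptions, not the paper's exact implementations.

import tensorflow as tf
from tensorflow.keras.layers import Layer, GlobalAveragePooling2D, Dense, Reshape

class DWTLayer(Layer):
    """Wraps the single-level Haar DWT sketched earlier as a Keras layer (assumed design)."""
    def call(self, x):
        return haar_dwt_2d(x)      # halves H and W, quadruples channels

class IDWTLayer(Layer):
    """Inverse of DWTLayer (assumed design)."""
    def call(self, x):
        return haar_idwt_2d(x)     # doubles H and W, quarters channels

class AttentionLayer(Layer):
    """Squeeze-and-excitation style channel attention (one plausible choice)."""
    def __init__(self, reduction=4):
        super().__init__()
        self.reduction = reduction

    def build(self, input_shape):
        channels = int(input_shape[-1])
        self.squeeze = GlobalAveragePooling2D()
        self.fc1 = Dense(max(channels // self.reduction, 1), activation='relu')
        self.fc2 = Dense(channels, activation='sigmoid')
        self.reshape = Reshape((1, 1, channels))

    def call(self, x):
        w = self.squeeze(x)        # (batch, C): global channel statistics
        w = self.fc2(self.fc1(w))  # (batch, C): per-channel weights in [0, 1]
        return self.reshape(w)     # broadcastable attention map (batch, 1, 1, C)

# Perceptual loss from a fixed, pretrained VGG19 (illustrative layer choice)
_vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
_feat = tf.keras.Model(_vgg.input, _vgg.get_layer('block3_conv3').output)
_feat.trainable = False

def vgg_loss(enhanced_img, ground_truth_img):
    """L1 distance between VGG19 feature maps; assumes images are in [0, 1]."""
    pre = tf.keras.applications.vgg19.preprocess_input
    return tf.reduce_mean(tf.abs(_feat(pre(enhanced_img * 255.0)) -
                                 _feat(pre(ground_truth_img * 255.0))))

The attention map returned here has shape (batch, 1, 1, C), so it broadcasts over the spatial dimensions when multiplied with the input features in the MSC module.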