nox star removal

The gallery above shows some astrophotographs that have had their stars automatically removed using nox. nox is a neural network trained exclusively on an artificial training dataset that I created. It is otherwise based on StarNet: it shares almost the same architecture and was trained using the same losses.

I can think of three good reasons to remove the stars from an astrophotograph:

  1. For art: Stars can be a distraction, and so starless images emphasise the nebulosity and are usually highly dramatic and pleasing to look at.
  2. For processing: Stars can be removed to allow nebulosity and other faint details to be enhanced without bloating the stars, which can be added back in later.
  3. For accurate star masks: Isolating stars allows the creation of star masks without any areas of nebulosity creeping in. (Accurate star masks are useful for various reasons, including selective processing to enhance faint details, assessing and removing background noise and light pollution, and color calibration.)

Unfortunately, manually removing stars is difficult and highly labor intensive, and subjective decisions must be made about how to fill the areas previously occupied by stars.

StarNet

Nikita Misiura developed StarNet, a deep convolutional neural network that automatically removes stars from astrophotographs. Nikita developed the model’s architecture and the training technique with adversarial and perceptual losses, and spent many hundreds of hours manually removing stars from astrophotographs to create a training dataset.

Architecture

The StarNet architecture is a convolutional encoder-decoder residual network. This is a powerful architecture for image noise reduction, super resolution, inpainting, and so on, and is well studied in computer vision and image processing (see here, for example). It is also an architecture type that I have used before for noise reduction in astrophotography, so I have some familiarity and experience with it.

In short, a convolutional neural network (CNN) uses successive layers of filters that slide across their inputs to produce responses known as feature maps.

Encoder-decoder CNNs first encode the input as a “hidden” representation by passing it through a series of convolutional (filter) layers. They then decode the hidden representation to generate the output by passing it through another series of deconvolutional layers. Encoder-decoder networks essentially translate one image into another, passing it through a hidden numerical representation along the way.

Residual networks (ResNets) are a specific type of neural network in which there is a direct connection from the input to the output (and sometimes between other layers too). This is called a “skip connection”. In a sense, they compute the difference (the “residual”) between the input and the desired output. This simple idea allows more complex (deeper) neural networks with many layers to be trained to solve complex tasks. Before ResNets were introduced (in 2015), adding more layers made networks more difficult to train and degraded their accuracy.
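
As a minimal sketch of the residual idea (assuming Keras; this is illustrative, not the StarNet code itself):

import tensorflow as tf

inp = tf.keras.layers.Input(shape = (512, 512, 1))
# a stand-in for a stack of convolutional layers that predicts the residual
# (for star removal, the residual is essentially the stars themselves)
residual = tf.keras.layers.Conv2D(1, 3, padding = "same")(inp)
# the skip connection: subtract the predicted residual from the input
out = tf.keras.layers.Subtract()([inp, residual])
model = tf.keras.Model(inp, out)

The full generator below ends with exactly this pattern: its final operation subtracts the predicted star signal from the input image.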

The StarNet generator network has 16 convolutional/deconvolutional layers. The generator is the network that translates a starry image into a starless one. In addition, StarNet has a discriminator network (with 10 layers), which learns to distinguish “fake” starless images (created by the generator) from “true” starless (“ground truth”) images (ones created manually, for instance).

Nikita has made the architecture and training code publicly available on github, here. It is licensed under the MIT License, which grants permission, free of charge, to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, provided the copyright notice accompanies copies or substantial portions of the Software.

Training

StarNet was trained using adversarial and perceptual losses. Adversarial losses come from the discriminator. The discriminator is trained in parallel with the generator to identify fake starless images. Every training iteration, both a generated starless image (from the generator) and a ground truth starless version of the same starry image are fed into the discriminator. It is more complex than this, but the principle is that if the discriminator believes the generated image is fake, this generates a high generator “cost”, and vice versa. As training proceeds, attempting to reduce the cost, the discriminator gets better at identifying fake images, but the generator also gets better at creating convincing fake images. The generator and discriminator face off against each other in an arms race until the generator is so good that the discriminator is effectively just guessing and is only correct 50% of the time on average. It is a highly effective training technique.
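
In equation form (matching the loss code later in this post), let x be the starry input, y the ground truth starless image, G(x) the generated starless image, and D(·) the discriminator’s estimated probability that an image is real:

dis_loss     = -mean[ log(D(y)) + log(1 - D(G(x))) ]
gen_loss_GAN = -mean[ log(D(G(x))) ]

(A small constant, 10^{-8}, is added inside each logarithm for numerical stability.)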

In addition, StarNet training uses perceptual losses. The idea behind perceptual losses is to measure whether two images look alike to a human, even if there are mathematical differences between them (like an offset, or small-scale variation), so that generated images look natural. StarNet achieves this by using the discriminator’s intermediate feature maps to compare the generated and ground truth images at different scales, weighting the different scales according to their importance to perception.
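
Concretely (again matching the loss code below), if D_k(·) denotes the feature maps produced at the k-th layer of the discriminator, each perceptual term is the mean absolute difference between feature maps:

gen_p_k = mean| D_k(G(x)) - D_k(y) |,  k = 1 ... 8

and the total generator loss combines these with per-scale weights, the adversarial term, and a pixel-wise L1 term:

gen_loss = 0.1*gen_loss_GAN + 0.1*gen_p1 + 10*(gen_p2 + ... + gen_p8) + 100*gen_L1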

nox

In my recent work, I have developed a version of StarNet, which I am calling nox. It is based on a different training dataset and uses a very similar architecture to StarNet; however, I use Layer Normalization between layers instead of Batch Normalization, since this improved training speed and stability for me. I have independently written my own training and inference code (largely an adaptation of a previous machine learning project), copying only the StarNet architecture and the specific code for calculating adversarial and perceptual losses from the StarNet code on github.
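
Concretely, the normalization swap is a one-line change in each layer of the generator and discriminator (a sketch using the Keras layers that appear in the code below):

# StarNet (original): statistics computed across the batch dimension
normalized = tf.keras.layers.BatchNormalization()(convolved)

# nox: statistics computed per sample across the feature dimensions,
# independent of batch size (relevant here, since the batch size is only 4)
normalized = tf.keras.layers.LayerNormalization()(convolved)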

I wanted my own star removal capability to feed into my other astrophotography processing work, with my own weights trained from initialization, and with enhanced capability to remove stars in cases where StarNet did not see many training images (e.g., different optical systems, or images with dense star fields, from which stars are hard to remove manually).

Artificial training dataset

It is commonly said in the field of machine learning that 25% of the work is in developing the architecture, and the remaining 75% is in creating the dataset. I’m sure Nikita would agree…

I just didn’t have the patience to remove stars manually to create a dataset, and figured it would be easier to add artificial but representative stars onto any random ground truth image. I added artificial stars to images from the SIDD and RENOIR datasets (which I happened to have to hand). These form training pairs in which the starry image is the input, and the original image with no stars is the desired output.

I used the Astropy and Photutils Python libraries, simulating various optical system parameters (arcseconds per pixel, FWHM, diffraction spikes) and atmospheric qualities (profile parameters, stability, noise). I also used the trend of the local luminosity function to get a realistic spread of star brightness versus frequency (the logarithm of number density versus brightness follows a triangular distribution).
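
The mid-sized and large stars are modeled with Moffat profiles; the smallest, dimmest stars are simple Gaussians. With amplitude A, core width γ and power index α (the “beta” of the reference linked in the code), the Moffat radial profile is

f(r) = A*(1 + (r/γ)²)^(-α)

and the FWHM fixes γ via

γ = FWHM/(2*sqrt(2^(1/α) - 1)),

which is the relation used in add_stars below.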

Here are some grayscale examples (I am still working on the color version):

nox training

I generated 2500 random 512×512 training image pairs and split the dataset 80:20 between training and validation. I trained the network from initialization on an Nvidia Jetson AGX Orin (which includes a 2048-core Nvidia Ampere architecture GPU with 64 Tensor cores) using a batch size of 4 (limited by memory). Training used a generator learning rate of 10^{-4} throughout. Here are the loss and PSNR histories after 213 epochs, towards the end of the training:

nox performance

Since the network takes 512×512 tiles as input, images must first be decomposed into tiles before they can be destarred. To avoid edge artefacts, my code discards a 192-pixel-wide border around every tile after destarring, so the input tiles must overlap; this is equivalent to a stride of 512 − 2×192 = 128 pixels. After destarring, the tiles with discarded borders are reassembled.

Some inference examples are shown above using real astrophotographs (which were unseen during training, since training uses artificial data only). Click an image to expand, then view and scroll in Gallery Mode with captions.

Next steps

This Python implementation is really just development code and is not suitable for deployment. nox will become a tool in my astrophotography processing software, Hera. The model and weights will be converted into an OpenCV deep neural network (DNN) object and compiled into the Hera binaries. Within Hera, nox will be available as a free-to-use image processing tool in its own right, and will also be built into other tools, for example the star masking tool.

Before doing this, however, I will also develop a color dataset to train a color version of nox to remove stars from color images.
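
For reference, here is a rough sketch of one possible conversion route, via ONNX (tf2onnx, the input signature, and the layout details are assumptions on my part, not a settled design):

import numpy as np
import tensorflow as tf
import tf2onnx
import cv2

model = generator() # the nox generator, as defined in the code below
model.load_weights('generator_gray.h5')

# export the trained generator to ONNX
spec = (tf.TensorSpec((1, 512, 512, 1), tf.float32, name = 'gen_input_image'),)
tf2onnx.convert.from_keras(model, input_signature = spec, output_path = 'nox.onnx')

# load and run the model with OpenCV's DNN module
net = cv2.dnn.readNetFromONNX('nox.onnx')
tile = np.zeros((1, 512, 512, 1), np.float32) # placeholder tile; in practice, rescaled to [-1, 1]
net.setInput(tile)
starless = net.forward()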

Python code

Generating training pairs

This code generates 2500 image pairs, dividing the work between 10 parallel threads (NumPy and OpenCV release the GIL, so the threads can keep multiple CPU cores busy). You can adjust these numbers to your preference. You will also need to have downloaded the SIDD and RENOIR datasets and installed the relevant packages, as per the imports.

import matplotlib.pyplot as plt
from matplotlib import interactive
interactive(True)
from astropy.modeling.models import Moffat2D
from photutils.datasets import (make_model_sources_image, make_random_models_table,
                                make_gaussian_sources_image, make_random_gaussians_table,
                                apply_poisson_noise)
import numpy as np
import sys
import pathlib
import os
import glob
import cv2
from threading import Thread

save_dir = os.getcwd() + os.sep + "nox data" + os.sep
cpus = 10
imcount = 2500
start = 0
preview_pairs = False # generate side-by-side image pairs for previewing

def get_images_paths(root_dir_sidd, root_dir_mi):
    starless_lst = []
    original_lst = []
    
    # SIDD dataset image paths
    root = pathlib.Path(root_dir_sidd)
    img_paths = list(root.rglob("*.PNG*"))
    img_paths_lst = [str(path) for path in img_paths]
    for p in img_paths_lst:
        img_type = p.split(os.sep)[-1].split('_')[-3]
        if img_type in ("NOISY", "GT"): # both noisy and ground-truth images serve as star-free backgrounds
            original_lst.append(p)
            starless_lst.append(p)

    # RENOIR dataset image paths
    for p in [x[0] for x in os.walk(root_dir_mi)]:
        noisyImgs = glob.glob(p + os.sep + '*Noisy.bmp')
        original_lst.extend(noisyImgs)
        starless_lst.extend(noisyImgs)
        refImag = glob.glob(p + os.sep + '*Reference.bmp')
        original_lst.extend(refImag)
        starless_lst.extend(refImag)

    original_array = np.asarray(original_lst)
    starless_array = np.asarray(starless_lst)
    return original_array, starless_array
    
def get_random_crop(image, label, crop_height, crop_width):

    max_x = image.shape[1] - crop_width
    max_y = image.shape[0] - crop_height

    x = np.random.randint(0, max_x)
    y = np.random.randint(0, max_y)

    image_crop = image[y: y + crop_height, x: x + crop_width]
    label_crop = label[y: y + crop_height, x: x + crop_width]

    return image_crop, label_crop

def midtones(x, s, m):
    y = np.clip(x, a_min = s, a_max = None) # set all values less than s to s
    y = (y - s)/(1. - s) # rescale from 0 to 1 after clipping
    y = (m - 1.)*y/((2.*m - 1.)*y - m) # midtones transfer function (m = 0.5 for unity)
    return y
    
def add_stars(image, label):
    starModel = Moffat2D()
    shape = image.shape[0:2]

    app = np.random.uniform(low = 0.8, high = 2.2) # arcseconds per pixel, constant per image
    FWHM0 = np.random.triangular(left = 1.8, mode = 2.5, right = 3.2)/app # random FWHM in pixels, constant per image
    m = np.random.uniform(low = 0.01, high = 0.1) # midtones stretch value
    gain = np.random.uniform(low = 0.1, high = 2.)
    starblur = np.random.uniform(low = 0., high = 3.)
    
    # middle-sized stars (large enough for a properly resolved source function, not large enough for bloat) with some midtones transformation
    n_src = 300
    param_ranges = {'amplitude': [1., 1.], 'x_0': [0, shape[1]], 'y_0': [0, shape[0]], 'gamma': [1., 1.], 'alpha': [2.0, 10]} # for range of alpha, see https://academic.oup.com/mnras/article/328/3/977/1247204 ('beta' in reference)
    sources = make_random_models_table(n_src, param_ranges, seed = None)
    sources['amplitude'] = np.multiply(np.power(10., np.random.triangular(left = 0., mode = 0., right = 1., size = n_src)) - 1., 3./9) # replace amplitudes with triangular distribution of magnitudes following trend of local luminosity function, resulting in range 0 to 3
    sources['gamma'] = FWHM0/2./np.sqrt(pow(2., 1./sources['alpha']) - 1.) # calculate gamma according to alpha and FWHM
    stars1 = make_model_sources_image(shape, starModel, sources)
    stars1 = apply_poisson_noise(stars1*255*gain, seed=None)/255/gain # add Poisson noise
    stars1 = midtones(stars1, 0., m)
    stars1 = cv2.GaussianBlur(stars1, (0, 0), starblur)
    theta0 = np.random.uniform(0, np.pi/2)
    yes_no_spikes = np.random.randint(2)
    if yes_no_spikes:
        sources['amplitude'] = np.clip(sources['amplitude'], 0., 1.) # cap amplitude at 1. for adding spikes
        stars1 = add_spikes(stars1, sources, theta0)
    
    # large stars - large enough for full-well capacity bleeding / bloating
    n_large = 20
    param_ranges = {'amplitude': [1., 1.], 'x_0': [0, shape[1]], 'y_0': [0, shape[0]], 'gamma': [1., 1.], 'alpha': [2.0, 10]} # for range of alpha, see https://academic.oup.com/mnras/article/328/3/977/1247204 ('beta' in reference)
    sources = make_random_models_table(n_large, param_ranges)
    sources['amplitude'] = np.multiply(np.power(10., np.random.triangular(left = 0., mode = 0., right = 1., size = n_large)) - 1., 1./9) + 1. # replace amplitudes with triangular distribution of magnitudes following trend of local luminosity function, resulting in range 1 to 2
    FWHM = np.random.triangular(left = FWHM0, mode = FWHM0, right = 15., size = n_large) # use range of wide FWHMs to simulate full-well capacity bleeding / bloating
    sources['gamma'] = FWHM/2./np.sqrt(pow(2., 1./sources['alpha']) - 1.) # calculate gamma according to alpha and FWHM
    stars2 = make_model_sources_image(shape, starModel, sources)
    stars2 = apply_poisson_noise(stars2*255*gain, seed=None)/255/gain # add Poisson noise
    stars2 = midtones(stars2, 0., m)
    stars2 = cv2.GaussianBlur(stars2, (0, 0), starblur)
    if yes_no_spikes: stars2 = add_spikes(stars2, sources, theta0)
    
    # small and dim Gaussian stars - subject to a bit more 'wobble'
    n_small = 2500
    stretch_range = 1.1 # range of star 'smearing'
    param_ranges = {'amplitude': [1., 1.], 'x_mean': [0, shape[1]], 'y_mean': [0, shape[0]], 'x_stddev': [1., 1.], 'y_stddev': [1., 1.], 'theta': [0., np.pi]}
    sources = make_random_gaussians_table(n_small, param_ranges, seed = None)
    sources['amplitude'] = np.multiply(np.power(10., np.random.triangular(left = 0., mode = 0., right = 1., size = n_small)) - 1., 0.1/9) # replace amplitudes with triangular distribution of magnitudes following trend of local luminosity function, resulting in range 0 to 0.1
    xstretch = np.random.uniform(low = 1./stretch_range, high = stretch_range) # generate uniform distribution of FWHM x stretch factors
    ystretch = np.random.uniform(low = 1./stretch_range, high = stretch_range) # generate uniform distribution of FWHM x stretch factors
    sources['x_stddev'] = np.multiply(FWHM0/2.35, xstretch)
    sources['y_stddev'] = np.multiply(FWHM0/2.35, ystretch)
    stars3 = make_gaussian_sources_image(shape, sources)
    stars3 = apply_poisson_noise(stars3*255*gain, seed=None)/255/gain # add Poisson noise
    stars3 = midtones(stars3, 0., m)
    stars3 = cv2.GaussianBlur(stars3, (0, 0), starblur)
    
    # combine stars to image
    stars = stars1 + stars2 + stars3
    image = 1. - (1. - stars)*(1. - image) # screen blend stars

    image = np.clip(image, 0., 1.)
    return image, label

def add_spikes(stars, sources, theta0): # add simulated diffraction spikes
    FWHM = 2.*sources['gamma']*np.sqrt(pow(2., 1./sources['alpha']) - 1.)
    n_src = len(sources['amplitude'])
    shape = stars.shape
    sources['x_mean'] = sources['x_0']
    del sources['x_0']
    sources['y_mean'] = sources['y_0']
    del sources['y_0']
    del sources['gamma']
    del sources['alpha']
    sources['amplitude'] = np.multiply(sources['amplitude'], np.random.uniform(low = 0.8, high = 1.0, size = n_src))
    sources['x_stddev'] = np.multiply(np.multiply(FWHM, sources['amplitude']/15), np.random.uniform(low = 1., high = 1.2, size = n_src))
    sources['y_stddev'] = np.multiply(np.multiply(FWHM, sources['amplitude']), np.random.uniform(low = 1., high = 1.5, size = n_src))
    sources['theta'] = np.ones(n_src)*theta0
    spikes = make_gaussian_sources_image(shape, sources)
    sources['theta'] = sources['theta'] + np.pi/2
    spikes += make_gaussian_sources_image(shape, sources)
    spikes = np.clip(spikes, 0., 1.)
    spikes = cv2.GaussianBlur(spikes, (0, 0), np.random.uniform(low = 0.5, high = 1.0))
    return 1. - (1. - stars)*(1. - spikes) # screen blend spikes

def create_images(images, labels, i_from, i_to):
    for i in range(i_from, i_to):
        print('image %d'%i)
        if os.path.exists(save_dir + 'x%d.png'%i) and os.path.exists(save_dir + 'y%d.png'%i): continue
        
        image = cv2.imread(images[i%len(images)])
        label = cv2.imread(labels[i%len(labels)])
        
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)/255.
        label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)/255.
        
        image, label = get_random_crop(image, label, 512, 512)
        image, label = add_stars(image, label)
        
        cv2.imwrite(save_dir + 'x%d.png'%i, cv2.cvtColor(np.float32(image)*255., cv2.COLOR_GRAY2BGR))
        cv2.imwrite(save_dir + 'y%d.png'%i, cv2.cvtColor(np.float32(label)*255., cv2.COLOR_GRAY2BGR))
        
        if preview_pairs:
            # plot training images
            plt.close()
            fig, ax = plt.subplots(1, 2, sharex = True, figsize=(21, 21))
            ax[0].imshow(cv2.cvtColor(np.float32(image), cv2.COLOR_GRAY2BGR))
            ax[0].set_axis_off()
            ax[1].imshow(cv2.cvtColor(np.float32(label), cv2.COLOR_GRAY2BGR))
            ax[1].set_axis_off()
            fig.tight_layout()
            plt.savefig(save_dir + '%i.png'%i, bbox_inches = 'tight')

if __name__ == "__main__":

    images, labels = get_images_paths("../datasets/SIDD_Small_sRGB_Only", "../datasets/RENOIR")
    
    # shuffle
    training_data = list(zip(images, labels))
    np.random.shuffle(training_data)
    images, labels = zip(*training_data)
    
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    
    if not preview_pairs:
        images_per_cpu = int((imcount - start)/cpus)
        threads = []
        for cpu in range(cpus):
            thread = Thread(target = create_images, args = (images, labels, start + cpu*images_per_cpu, start + (cpu + 1)*images_per_cpu))
            threads.append(thread)
            threads[-1].start()
        
        for i in range(len(threads)):
            threads[i].join()
            print('joined thread %i'%i)

    else: create_images(images, labels, 0, imcount) # matplotlib only runs in main thread

Training and inference

This code can be run in two modes:

  1. Training mode (“python ./nox.py train”)
  2. Inference mode (“python ./nox.py infer input.tiff”), which outputs a file “starless.png” in the current working directory.

You will need to have installed TensorFlow and various other Python packages, as per the imports. Note that the input and output image formats do not have to be TIFF and PNG, respectively.

import sys
import pathlib
import os
import numpy as np
import cv2
import skimage
from sklearn.model_selection import train_test_split
import glob
import matplotlib.pyplot as plt
from matplotlib import interactive
interactive(True)
import pickle
import gc
import csv
import time
import tensorflow as tf
import pandas as pd

# enable dynamic memory allocation
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
for device in gpu_devices:
    tf.config.experimental.set_memory_growth(device, True)

# global variables
epochs = 1500
n_channels = 1
patch_size = 512
stride = 128
border = int((patch_size - stride)/2)
print('border = ' + str(border))

# training parameters
BATCH_SIZE = 4 # number of training samples to work through before the model’s parameters are updated
ema = 0.9999 # exponential moving average: keep 99.99% of the existing state and 0.01% of the new state
lr = 1e-4 # learning rate
save_backups = True
validation = True

def generator():
    tf.keras.backend.clear_session() # release global state of old models and layers
    layers = []
    filters = [64, 128, 256, 512, 512, 512, 512, 512, 512, 512, 512, 512, 256, 128, 64] # filter layers 0 - 14
    input = tf.keras.layers.Input(shape = (patch_size, patch_size, n_channels), name = "gen_input_image")
    for i in range(1 + len(filters)):
        if i == 0: # layer 0
            convolved = tf.keras.layers.Conv2D(filters[0], kernel_size = 4, strides = (2, 2), padding = "same", kernel_initializer = tf.initializers.GlorotUniform())(input)
            layers.append(convolved)
        elif 1 <= i <= 7: # convolution layers
            rectified = tf.keras.layers.LeakyReLU(alpha = 0.2)(layers[-1])
            convolved = tf.keras.layers.Conv2D(filters[i], kernel_size = 4, strides = (2, 2), padding = "same", kernel_initializer = tf.initializers.GlorotUniform())(rectified)
            normalized = tf.keras.layers.LayerNormalization()(convolved)
            layers.append(normalized)
        elif 8 <= i <= 14: # deconvolution layers
            if i == 8:
                rectified = tf.keras.layers.ReLU()(layers[-1])
            else:
                concatenated = tf.concat([layers[-1], layers[15 - i]], axis = 3)
                rectified = tf.keras.layers.ReLU()(concatenated)
            deconvolved = tf.keras.layers.Conv2DTranspose(filters[i], kernel_size = 4, strides = (2, 2), padding = "same", kernel_initializer = tf.initializers.GlorotUniform())(rectified)
            normalized = tf.keras.layers.LayerNormalization()(deconvolved)
            layers.append(normalized)
        else: # layer 15
            concatenated = tf.concat([layers[-1], layers[0]], axis = 3)
            rectified = tf.keras.layers.ReLU()(concatenated)
            deconvolved = tf.keras.layers.Conv2DTranspose(n_channels, kernel_size = 4, strides = (2, 2), padding = "same", kernel_initializer = tf.initializers.GlorotUniform())(rectified)
            rectified = tf.keras.layers.ReLU()(deconvolved)
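            # residual skip connection: the network predicts the star signal, which is subtracted from the input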
            output = tf.math.subtract(input, rectified)
    return tf.keras.Model(inputs = input, outputs = output, name = "generator")

def discriminator():
    layers = []
    filters = [32, 64, 64, 128, 128, 256, 256, 256, 8]
    input = tf.keras.layers.Input(shape = (patch_size, patch_size, n_channels), name = "dis_input_image")
    for i in range(1 + len(filters)):
        if i % 2 == 1 or i == 8:
            padding = "valid"
            strides = (2, 2)
        else:
            padding = "same"
            strides = (1, 1)
        if i == 0: # layer 0
            convolved = tf.keras.layers.Conv2D(filters[i], kernel_size = 3, strides = strides, padding = padding)(input)
            rectified = tf.keras.layers.LeakyReLU(alpha = 0.2)(convolved)
            layers.append(rectified)
        elif 1 <= i <= 8: # convolution layers
            convolved = tf.keras.layers.Conv2D(filters[i], kernel_size = 3, strides = strides, padding = padding)(layers[-1])
            normalized = tf.keras.layers.LayerNormalization()(convolved)
            rectified = tf.keras.layers.LeakyReLU(alpha = 0.2)(normalized)
            layers.append(rectified)
        else: # layer 9
            dense = tf.keras.layers.Dense(1)(layers[-1])
            sigmoid = tf.nn.sigmoid(dense)
            layers.append(sigmoid)
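    # expose the intermediate feature maps (used for the perceptual losses) alongside the final real/fake prediction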
    output = [layers[0], layers[1], layers[2], layers[3], layers[4], layers[5], layers[6], layers[7], layers[-1]]
    return tf.keras.Model(inputs = input, outputs = output, name = "discriminator")
    
def get_images_paths(root_dir):
    starless_lst = []
    original_lst = []
    
    root = pathlib.Path(root_dir)
    x_paths = sorted(list(root.rglob("x*.png*")))
    y_paths = sorted(list(root.rglob("y*.png*")))
    
    x_paths_lst = [str(path) for path in x_paths]
    for p in x_paths_lst:
        original_lst.append(p)
        
    y_paths_lst = [str(path) for path in y_paths]
    for p in y_paths_lst:
        starless_lst.append(p)

    original_array = np.asarray(original_lst)
    starless_array = np.asarray(starless_lst)
    return original_array, starless_array
    
def up_down_flip(image, label):
    if tf.random.uniform(shape = [], minval = 0, maxval = 2, dtype = tf.int32) == 1:
        image = tf.image.flip_up_down(image)
        label = tf.image.flip_up_down(label)
    return image, label

def left_right_flip(image, label):
    if tf.random.uniform(shape = [], minval = 0, maxval = 2, dtype = tf.int32) == 1:
        image = tf.image.flip_left_right(image)
        label = tf.image.flip_left_right(label)
    return image, label

def rotate_90(image, label):
    rand_value = tf.random.uniform(shape = [], minval = 0, maxval = 4, dtype = tf.int32)
    image = tf.image.rot90(image, rand_value)
    label = tf.image.rot90(label, rand_value)
    return image, label
    
def adjust_brightness(image, label):
    if tf.random.uniform(shape = []) < 0.7:
        m = tf.math.minimum(image, label)
        offset = tf.random.uniform(shape = [])*0.25 - tf.random.uniform(shape = [])*m
        image = image + offset*(1. - image)
        label = label + offset*(1. - label)
        image = tf.clip_by_value(image, 0. , 1.)
        label = tf.clip_by_value(label, 0. , 1.)
    return image, label

def process_path(path_original, path_starless):
    img_original = tf.io.read_file(path_original)
    img_original = tf.io.decode_image(img_original, dtype = tf.dtypes.float32)
    img_starless = tf.io.read_file(path_starless)
    img_starless = tf.io.decode_image(img_starless, dtype = tf.dtypes.float32)
    if n_channels == 1:
        img_original = tf.image.rgb_to_grayscale(img_original)
        img_starless = tf.image.rgb_to_grayscale(img_starless)
    return img_original, img_starless

def random_crop(image, label):
    combined = tf.concat([image, label], axis=2)
    combined_crop = tf.image.random_crop(combined, [patch_size, patch_size, n_channels*2])
    return (combined_crop[:, :, :n_channels], combined_crop[:, :, n_channels:])

def data_generator(X, y, batch_size, augmentations = None):
    dataset = tf.data.Dataset.from_tensor_slices((X, y))
    dataset = dataset.shuffle(buffer_size = 100, reshuffle_each_iteration = True)
    dataset = dataset.map(process_path, num_parallel_calls = tf.data.AUTOTUNE)
    # dataset = dataset.map(random_crop, num_parallel_calls = tf.data.AUTOTUNE) # uncomment if dataset tiles are oversized
    
    if augmentations:
        for f in augmentations:
            dataset = dataset.map(f, num_parallel_calls = tf.data.AUTOTUNE)
    
    dataset = dataset.repeat() # re-initialize the dataset as soon as all the entries have been read
    dataset = dataset.batch(batch_size = batch_size, drop_remainder = True) # make batches out of dataset
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) # allows later elements to be prepared while the current element is being processed, improving latency and throughput, at the cost of using additional memory to store prefetched elements
    return dataset
    
def PSNR(y_true, y_pred):
    return tf.image.psnr(y_pred, y_true, max_val = 1.)

def inference_single_tile(model, original_image):
    input_image = np.expand_dims(original_image, axis = 0)
    predicted_image = (model.predict(input_image*2. - 1.) + 1.)/2.
    return predicted_image[0]
    
def inference_batch_tiles(model, original_images):
    predicted_images = (model.predict(original_images*2. - 1., batch_size = BATCH_SIZE) + 1.)/2.
    return predicted_images
    
def infer_image(model, original_image, border = 0):
    predicted_image = np.zeros(original_image.shape)
    sizeX = original_image.shape[1]
    sizeY = original_image.shape[0]
    
    # add border around original image
    original_image = cv2.copyMakeBorder(original_image, border, border, border, border, cv2.BORDER_REFLECT)
    sizeX = original_image.shape[1]
    sizeY = original_image.shape[0]
    
    # split original image into tiles
    fromRow = 0 # row pixel index
    fromCol = 0 # col pixel index
    nRows = 0
    nCols = 0
    original_lst = []
    finalRow = False
    while True: # loop rows
        nRows += 1
        toRow = fromRow + patch_size
        if toRow >= sizeY:
            fromRow = sizeY - patch_size
            toRow = sizeY
            finalRow = True
        
        finalCol = False
        while True: # loop cols
            if finalRow:
                nCols += 1 # count columns on final row only
            
            toCol = fromCol + patch_size
            if toCol >= sizeX:
                fromCol = sizeX - patch_size
                toCol = sizeX
                finalCol = True
            
            original_roi = original_image[fromRow:toRow, fromCol:toCol]
            original_lst.append(original_roi)
            
            if finalCol:
                fromCol = 0
                break
            else:
                fromCol += (patch_size - 2*border)
                
        if finalRow:
            break
        else:
            fromRow += (patch_size - 2*border)
    
    # infer batch of tiles
    original_rois = np.array(original_lst)
    predicted_rois = inference_batch_tiles(model, original_rois)
    
    # build up predicted image from predicted tiles
    fromRow = 0 # row pixel index
    fromCol = 0 # col pixel index
    count = 0
    for i in range(0, nRows):
        toRow = fromRow + patch_size - 2*border
        if i == nRows - 1:
            fromRow = (sizeY - 2*border) - (patch_size - 2*border)
            toRow = (sizeY - 2*border)
            finalRow = True
        
        for j in range(0, nCols):
            toCol = fromCol + patch_size - 2*border
            
            if j == nCols - 1:
                fromCol = (sizeX - 2*border) - (patch_size - 2*border)
                toCol = (sizeX - 2*border)
                finalCol = True
            
            predicted_image[fromRow:toRow, fromCol:toCol] = predicted_rois[count][border:patch_size - border, border:patch_size - border]
            count = count + 1
            
            if j == nCols - 1:
                fromCol = 0
                break
            else:
                fromCol += (patch_size - 2*border)
        
        fromRow += (patch_size - 2*border)
    
    np.clip(predicted_image, 0., 1., predicted_image) # clip 0 to 1
    
    return predicted_image
    
def infer(file = ''):
    tf.keras.backend.clear_session()
    print('building generator...')
    model = generator()
    model.summary()
    
    if os.path.exists(os.getcwd() + os.sep + 'generator_gray.h5'):
        print('loading weights...')
        model.load_weights('generator_gray.h5')
    else: return
    
    # process image from disk
    if len(file) > 0 and os.path.exists(file):
        original_image = cv2.imread(file)
        if len(original_image.shape) == 3:
            if n_channels == 1:
                # convert straight from OpenCV's BGR ordering to grayscale
                original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY)
                original_image = original_image[:, :, np.newaxis]
            else:
                original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
        if original_image.dtype == 'uint8': original_image = original_image/255.
        cv2.imwrite('starry.png', np.float32(original_image)*255.)
        start_time = time.time()
        starless_image = infer_image(model, original_image, border = border) # infer
        print("--- %s seconds ---" % (time.time() - start_time))
        cv2.imwrite('starless.png', np.float32(starless_image)*255.)
        
    else: # process image from dataset
        original_array_paths, starless_array_paths = get_images_paths(os.getcwd() + os.sep + "nox data") # training and test image paths
        i = np.random.randint(0, original_array_paths.shape[0]) # get a random noisy image index
        original_image = cv2.imread(original_array_paths[i])
        if n_channels == 1:
            original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY)
            original_image = original_image[:, :, np.newaxis]
        cv2.imwrite('input.png', original_image)
        if n_channels == 3: original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
        if original_image.dtype == 'uint8': original_image = original_image/255.
        gt_image = cv2.imread(starless_array_paths[i])
        if n_channels == 1:
            gt_image = cv2.cvtColor(gt_image, cv2.COLOR_BGR2GRAY)
            gt_image = gt_image[:, :, np.newaxis]
        cv2.imwrite('gt.png', gt_image)
        if n_channels == 3: gt_image = cv2.cvtColor(gt_image, cv2.COLOR_BGR2RGB)
        if gt_image.dtype == 'uint8': gt_image = gt_image/255.
        predicted_image = infer_image(model, original_image, border) # infer
        if n_channels == 3: predicted_image = cv2.cvtColor(np.float32(predicted_image), cv2.COLOR_RGB2BGR) # convert back to BGR for imwrite (3-channel only)
        cv2.imwrite('output.png', np.float32(predicted_image)*255.)
        psnr1 = PSNR(original_image, gt_image).numpy()
        psnr2 = PSNR(predicted_image, gt_image).numpy()
        print('PSNR(original, gt) = ' + str(psnr1))
        print('PSNR(starless, gt) = ' + str(psnr2))

def train():
    tf.keras.backend.clear_session()
    
    # training and test image paths
    original_array_paths, starless_array_paths = get_images_paths(os.getcwd() + os.sep + "nox data")
    
    if validation:
        original_train_paths, original_test_paths, starless_train_paths, starless_test_paths = train_test_split(original_array_paths, starless_array_paths, test_size = 0.2, random_state = 42)
        print('number of original images for training = ' + str(original_train_paths.shape))
        print('number of original images for testing = ' + str(original_test_paths.shape))
        print('number of starless images for training = ' + str(starless_train_paths.shape))
        print('number of starless images for testing = ' + str(starless_test_paths.shape))
    else:
        original_train_paths = original_array_paths
        starless_train_paths = starless_array_paths
        print('number of original images for training = ' + str(original_train_paths.shape))
        print('number of starless images for training = ' + str(starless_train_paths.shape))
    
    # data generators
    augmentation_lst = [up_down_flip, left_right_flip, rotate_90, adjust_brightness]
    image_generator_train = data_generator(X = original_train_paths, y = starless_train_paths, batch_size = BATCH_SIZE, augmentations = augmentation_lst)
    image_generator_test = []
    if validation:
        image_generator_test = data_generator(X = original_test_paths, y = starless_test_paths, batch_size = BATCH_SIZE)
    
    # nets
    print('building generator...')
    G = generator()
    G.summary()
    print('building discriminator...')
    D = discriminator()
    D.summary()
    
    # optimizers
    gen_optimizer = tf.optimizers.Adam(learning_rate = lr)
    dis_optimizer = tf.optimizers.Adam(learning_rate = lr/4)
    
    # number of epochs and steps
    print('number of training epochs = ' + str(epochs))
    steps_per_epoch_train = int(len(original_train_paths)/BATCH_SIZE)
    print('steps per training epoch = ' + str(steps_per_epoch_train))
    steps_per_epoch_validation = 1
    if validation:
        steps_per_epoch_validation = int(len(original_test_paths)/BATCH_SIZE)
        print('steps per validation epoch = ' + str(steps_per_epoch_validation))
    
    hist_train = {}
    hist_val = {}
    
    # load weights
    e0 = -1
    if os.path.exists(os.getcwd() + os.sep + 'generator_gray.h5') \
        and os.path.exists(os.getcwd() + os.sep + 'discriminator_gray.h5') \
        and os.path.exists(os.getcwd() + os.sep + 'history.pkl'):
        print('loading weights...')
        G.load_weights('generator_gray.h5')
        D.load_weights('discriminator_gray.h5')
        with open('history.pkl', 'rb') as f:
            print('loading training history...')
            if validation:
                e0, hist_train, hist_val = pickle.load(f)
                print("\repoch %d: loss %f PSNR %f    val: loss %f PSNR %f                    \n"%(e0, hist_train['total'][-1], hist_train['psnr'][-1], hist_val['total'][-1], hist_val['psnr'][-1]), end = '')
            else: 
                e0, hist_train = pickle.load(f)
                print("\repoch %d: loss %f PSNR %f                    \n"%(e0, hist_train['total'][-1], hist_train['psnr'][-1]), end = '')

    # training epochs with validation
    # START: This portion of code is Copyright (c) 2018-2019 Nikita Misiura and used here under the MIT License
    for e in range(e0 + 1, epochs, 1):
        # training batches
        for i in range(steps_per_epoch_train):
            x, y = image_generator_train.take(1).as_numpy_iterator().next()
            x = x*2. - 1. # rescale from [0, 1] to [-1, 1]
            y = y*2. - 1.
            
            if i > 0:
                print("\repoch %d: it %d / %d loss %f PSNR %f    "%(e, i + 1, steps_per_epoch_train, hist_train['total'][-1], hist_train['psnr'][-1]), end = '')
            else:
                print("\repoch %d: it %d / %d "%(e, i + 1, steps_per_epoch_train), end = '')
            
            with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
                gen_output = G(x, training = True)
                
                p1_real, p2_real, p3_real, p4_real, p5_real, p6_real, p7_real, p8_real, predict_real = D(y, training = True)
                p1_fake, p2_fake, p3_fake, p4_fake, p5_fake, p6_fake, p7_fake, p8_fake, predict_fake = D(gen_output, training = True)
                
                d = {}
                
                dis_loss = tf.reduce_mean(-(tf.math.log(predict_real + 1E-8)+tf.math.log(1. - predict_fake + 1E-8)))
                d['dis_loss'] = dis_loss
                
                gen_loss_GAN = tf.reduce_mean(-tf.math.log(predict_fake + 1E-8))
                d['gen_loss_GAN'] = gen_loss_GAN
                
                gen_p1 = tf.reduce_mean(tf.abs(p1_fake - p1_real))
                d['gen_p1'] = gen_p1
                
                gen_p2 = tf.reduce_mean(tf.abs(p2_fake - p2_real))
                d['gen_p2'] = gen_p2
                
                gen_p3 = tf.reduce_mean(tf.abs(p3_fake - p3_real))
                d['gen_p3'] = gen_p3
                
                gen_p4 = tf.reduce_mean(tf.abs(p4_fake - p4_real))
                d['gen_p4'] = gen_p4
                
                gen_p5 = tf.reduce_mean(tf.abs(p5_fake - p5_real))
                d['gen_p5'] = gen_p5
                
                gen_p6 = tf.reduce_mean(tf.abs(p6_fake - p6_real))
                d['gen_p6'] = gen_p6
                
                gen_p7 = tf.reduce_mean(tf.abs(p7_fake - p7_real))
                d['gen_p7'] = gen_p7
                
                gen_p8 = tf.reduce_mean(tf.abs(p8_fake - p8_real))
                d['gen_p8'] = gen_p8
                
                gen_L1 = tf.reduce_mean(tf.abs(y - gen_output))
                d['gen_L1'] = gen_L1*100.
                
                d['psnr'] = tf.reduce_mean(PSNR(y, gen_output))
                
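                # total generator loss: adversarial term, perceptual terms at eight scales, and a strong L1 fidelity term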
                gen_loss = gen_loss_GAN*0.1 + gen_p1*0.1 + gen_p2*10. + gen_p3*10. + gen_p4*10. + gen_p5*10. + gen_p6*10. + gen_p7*10. + gen_p8*10. + gen_L1*100.
                d['total'] = gen_loss
                
                for k in d:
                    if k in hist_train.keys():
                        hist_train[k].append(d[k]*(1. - ema) + hist_train[k][-1]*ema)
                    else:
                        hist_train[k] = [d[k]]
                
                gen_grads = gen_tape.gradient(gen_loss, G.trainable_variables)
                gen_optimizer.apply_gradients(zip(gen_grads, G.trainable_variables))
                
                dis_grads = dis_tape.gradient(dis_loss, D.trainable_variables)
                dis_optimizer.apply_gradients(zip(dis_grads, D.trainable_variables))
        
        # validation batches
        if validation:
            for i in range(steps_per_epoch_validation):
                x, y = image_generator_test.take(1).as_numpy_iterator().next()
                x = x*2. - 1. # rescale from [0, 1] to [-1, 1]
                y = y*2. - 1.
                
                if i > 0:
                    print("\repoch %d: loss %f PSNR %f    val: it %d / %d loss %f PSNR %f    "%(e, hist_train['total'][-1], hist_train['psnr'][-1], i+1, steps_per_epoch_validation, hist_val['total'][-1], hist_val['psnr'][-1]), end = '')
                else:
                    print("\repoch %d: loss %f PSNR %f    val: it %d / %d    "%(e, hist_train['total'][-1], hist_train['psnr'][-1], i+1, steps_per_epoch_validation), end = '')
                
                gen_output = G(x)
                
                p1_real, p2_real, p3_real, p4_real, p5_real, p6_real, p7_real, p8_real, predict_real = D(y)
                p1_fake, p2_fake, p3_fake, p4_fake, p5_fake, p6_fake, p7_fake, p8_fake, predict_fake = D(gen_output)
                
                d = {}
                
                dis_loss = tf.reduce_mean(-(tf.math.log(predict_real + 1E-8) + tf.math.log(1. - predict_fake + 1E-8)))
                d['dis_loss'] = dis_loss
                
                gen_loss_GAN = tf.reduce_mean(-tf.math.log(predict_fake + 1E-8))
                d['gen_loss_GAN'] = gen_loss_GAN
                
                gen_p1 = tf.reduce_mean(tf.abs(p1_fake - p1_real))
                d['gen_p1'] = gen_p1
                
                gen_p2 = tf.reduce_mean(tf.abs(p2_fake - p2_real))
                d['gen_p2'] = gen_p2
                
                gen_p3 = tf.reduce_mean(tf.abs(p3_fake - p3_real))
                d['gen_p3'] = gen_p3
                
                gen_p4 = tf.reduce_mean(tf.abs(p4_fake - p4_real))
                d['gen_p4'] = gen_p4
                
                gen_p5 = tf.reduce_mean(tf.abs(p5_fake - p5_real))
                d['gen_p5'] = gen_p5
                
                gen_p6 = tf.reduce_mean(tf.abs(p6_fake - p6_real))
                d['gen_p6'] = gen_p6
                
                gen_p7 = tf.reduce_mean(tf.abs(p7_fake - p7_real))
                d['gen_p7'] = gen_p7
                
                gen_p8 = tf.reduce_mean(tf.abs(p8_fake - p8_real))
                d['gen_p8'] = gen_p8
                
                gen_L1 = tf.reduce_mean(tf.abs(y - gen_output))
                d['gen_L1'] = gen_L1*100.
                
                d['psnr'] = tf.reduce_mean(PSNR(y, gen_output))
                
                gen_loss = gen_loss_GAN*0.1 + gen_p1*0.1 + gen_p2*10. + gen_p3*10. + gen_p4*10. + gen_p5*10. + gen_p6*10. + gen_p7*10. + gen_p8*10. + gen_L1*100.
                d['total'] = gen_loss
                
                for k in d:
                    if k in hist_val.keys():
                        hist_val[k].append(d[k]*(1. - ema) + hist_val[k][-1]*ema)
                    else:
                        hist_val[k] = [d[k]]
            
            print("\repoch %d: loss %f PSNR %f    val: loss %f PSNR %f                    \n"%(e, hist_train['total'][-1], hist_train['psnr'][-1], hist_val['total'][-1], hist_val['psnr'][-1]), end = '')
        else:
            print("")

        # END: This portion of code is Copyright (c) 2018-2019 Nikita Misiura and used here under the MIT License
        
        if save_backups:
            G.save_weights('generator_gray.h5')
            D.save_weights('discriminator_gray.h5')
            if validation:
                with open('history.pkl', 'wb') as f:
                    pickle.dump([e, hist_train, hist_val], f)
            else:
                with open('history.pkl', 'wb') as f:
                    pickle.dump([e, hist_train], f)
        
        # CSV logger
        with open('history.csv', 'a', newline = '') as csv_file: # named csv_file to avoid shadowing the csv module
            if validation:
                if e == 0: csv_file.write('epoch,loss,psnr,val_loss,val_psnr,lr\n')
                csv_file.write('%d,%f,%f,%f,%f,%f\n'%
                    (e,
                    hist_train['total'][-1].numpy(),
                    hist_train['psnr'][-1].numpy(),
                    hist_val['total'][-1].numpy(),
                    hist_val['psnr'][-1].numpy(),
                    lr))
            else:
                if e == 0: csv_file.write('epoch,loss,psnr,lr\n')
                csv_file.write('%d,%f,%f,%f\n'%
                    (e,
                    hist_train['total'][-1].numpy(),
                    hist_train['psnr'][-1].numpy(),
                    lr))
        
        # plotting
        if validation and os.path.exists('history.csv'):
            df = pd.read_csv('history.csv')
            fig, ax = plt.subplots(1, 2, sharex = True)
            line1, = ax[0].plot(df.epoch, df.loss, 'b', label = 'training')
            line2, = ax[0].plot(df.epoch, df.val_loss, 'r', label = 'validation')
            ax[0].legend(handles = [line1, line2])
            ax[0].set_xlabel('epoch', fontsize = 12)
            ax[0].set_ylabel('loss', fontsize = 12)
            line1, = ax[1].plot(df.epoch, df.psnr, 'b', label = 'training')
            line2, = ax[1].plot(df.epoch, df.val_psnr, 'r', label = 'validation')
            ax[1].legend(handles = [line1, line2])
            ax[1].set_xlabel('epoch', fontsize = 12)
            ax[1].set_ylabel('psnr (dB)', fontsize = 12)
            fig.tight_layout()
            plt.savefig(os.getcwd() + os.sep + 'training.png', bbox_inches = 'tight')
            plt.close(fig) # release the figure so it does not accumulate across epochs
        elif not validation and os.path.exists('history.csv'):
            df = pd.read_csv('history.csv')
            fig, ax = plt.subplots(1, 2, sharex = True)
            line1, = ax[0].plot(df.epoch, df.loss, 'b', label = 'training')
            ax[0].legend(handles = [line1])
            ax[0].set_xlabel('epoch', fontsize = 12)
            ax[0].set_ylabel('loss', fontsize = 12)
            line1, = ax[1].plot(df.epoch, df.psnr, 'b', label = 'training')
            ax[1].legend(handles = [line1])
            ax[1].set_xlabel('epoch', fontsize = 12)
            ax[1].set_ylabel('psnr (dB)', fontsize = 12)
            fig.tight_layout()
            plt.savefig(os.getcwd() + os.sep + 'training.png', bbox_inches = 'tight')
            plt.close(fig) # release the figure so it does not accumulate across epochs
            
        if validation:
            # process random image from data set
            i = np.random.randint(0, original_array_paths.shape[0]) # get a random noisy image index
            original_image = cv2.imread(original_array_paths[i])
            if n_channels == 1:
                original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY)
                original_image = original_image[:, :, np.newaxis]
            cv2.imwrite('input.png', original_image)
            if n_channels == 3: original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
            if original_image.dtype == 'uint8': original_image = original_image/255.
            gt_image = cv2.imread(starless_array_paths[i])
            if n_channels == 1:
                gt_image = cv2.cvtColor(gt_image, cv2.COLOR_BGR2GRAY)
                gt_image = gt_image[:, :, np.newaxis]
            cv2.imwrite('gt.png', gt_image)
            if n_channels == 3: gt_image = cv2.cvtColor(gt_image, cv2.COLOR_BGR2RGB)
            if gt_image.dtype == 'uint8': gt_image = gt_image/255.
            predicted_image = infer_image(G, original_image, border = border) # infer
            if n_channels == 3: predicted_image = cv2.cvtColor(np.float32(predicted_image), cv2.COLOR_RGB2BGR) # convert back to BGR for imwrite (3-channel only)
            cv2.imwrite('output.png', np.float32(predicted_image)*255.)
            psnr1 = PSNR(original_image, gt_image).numpy()
            psnr2 = PSNR(predicted_image, gt_image).numpy()
            print('PSNR(input, gt) = ' + str(psnr1))
            print('PSNR(output, gt) = ' + str(psnr2))
                
        # housekeeping
        gc.collect()
        tf.keras.backend.clear_session()

if __name__ == "__main__":
    mode = 'train' # train mode by default
    if len(sys.argv) > 1: mode = sys.argv[1]
    
    if mode == 'train': train()
    elif mode == 'infer':
        if len(sys.argv) > 2: infer(sys.argv[2]) # file name
        else: infer() # draw randomly from dataset

MIT License

The relevant sections of this code, as indicated by comments, are used under the MIT License, as follows:

Code is Copyright (c) 2018-2019 Nikita Misiura

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
