You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
217 lines
7.1 KiB
217 lines
7.1 KiB
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
|
|
import math |
|
import random |
|
|
|
import numpy as np |
|
from scipy.stats import beta |
|
|
|
|
|
def fftfreqnd(h, w=None, z=None):
    """ Compute the frequency magnitude of every bin of a discrete fourier transform

    The per-axis bin frequencies come from `np.fft.fftfreq`; they are broadcast
    against each other and combined as sqrt(fx^2 + fy^2 + fz^2).

    :param h: Required, first dimension size
    :param w: Optional, second dimension size. Only the leading `w // 2 + 1`
        (even `w`) or `w // 2 + 2` (odd `w`) bins of this axis are kept.
    :param z: Optional, third dimension size
    :return: Array of frequency magnitudes
    """
    freq_h = np.fft.fftfreq(h)
    freq_w = 0
    freq_z = 0

    if w is not None:
        # Give the first axis a trailing unit dimension so it broadcasts
        # against the (truncated) second-axis frequencies.
        freq_h = np.expand_dims(freq_h, -1)
        kept_bins = w // 2 + (2 if w % 2 == 1 else 1)
        freq_w = np.fft.fftfreq(w)[:kept_bins]

    if z is not None:
        freq_h = np.expand_dims(freq_h, -1)
        # The third axis is used in full, whatever the parity of z.
        freq_z = np.fft.fftfreq(z)[:, None]

    return np.sqrt(freq_w * freq_w + freq_h * freq_h + freq_z * freq_z)
|
|
|
|
|
def get_spectrum(freqs, decay_power, ch, h, w=0, z=0):
    """ Draw random fourier coefficients whose magnitude decays with frequency

    Standard-normal noise of shape `[ch, *freqs.shape, 2]` is multiplied by
    `1 / max(freqs, 1 / max(w, h, z)) ** decay_power`.

    :param freqs: Bin values for the discrete fourier transform
    :param decay_power: Decay power for frequency decay prop 1/f**d
    :param ch: Number of channels for the resulting mask
    :param h: Required, first dimension size
    :param w: Optional, second dimension size
    :param z: Optional, third dimension size
    :return: Array of shape `[ch, *freqs.shape, 2]`
    """
    # Clamp the frequencies from below so the zero-frequency bin does not
    # produce a division by zero.
    lowest_freq = np.array([1. / max(w, h, z)])
    clamped = np.maximum(freqs, lowest_freq)
    scale = np.ones(1) / (clamped**decay_power)

    noise_shape = [ch] + list(freqs.shape) + [2]
    noise = np.random.randn(*noise_shape)

    # Add a leading channel axis and a trailing axis so the scale broadcasts
    # over the noise array.
    scale = np.expand_dims(scale, -1)[None, :]

    return scale * noise
|
|
|
|
|
def make_low_freq_image(decay, shape, ch=1):
    """ Sample a low frequency image from fourier space

    The result is shifted and scaled so that its minimum is 0 and its maximum is 1.

    :param decay: Decay power for frequency decay prop 1/f**d
    :param shape: Shape of desired mask, list up to 3 dims
    :param ch: Number of channels for desired mask
    :return: Array normalised to the range [0, 1]; only the first entry of
        the leading axis is kept
    """
    freqs = fftfreqnd(*shape)
    coeffs = get_spectrum(freqs, decay, ch, *shape)
    # NOTE(review): this indexes axis 1 of the sampled coefficients, not the
    # trailing size-2 axis appended by get_spectrum - confirm this is intended.
    complex_spectrum = coeffs[:, 0] + 1j * coeffs[:, 1]
    mask = np.real(np.fft.irfftn(complex_spectrum, shape))

    # Keep the first entry of the leading axis and crop every remaining axis
    # to the requested size (only done for 1 to 3 dimensional shapes).
    if 1 <= len(shape) <= 3:
        crop = (slice(None, 1), ) + tuple(slice(None, dim) for dim in shape)
        mask = mask[crop]

    # Rescale to [0, 1].
    shifted = mask - mask.min()
    return shifted / shifted.max()
|
|
|
|
|
def sample_lam(alpha, reformulate=False):
    """ Draw a mixing coefficient lambda from a beta distribution

    :param alpha: Alpha value for beta distribution
    :param reformulate: If True, sample from Beta(alpha + 1, alpha) instead of
        the symmetric Beta(alpha, alpha).
    :return: The sampled lambda
    """
    first_shape = alpha + 1 if reformulate else alpha
    return beta.rvs(first_shape, alpha)
|
|
|
|
|
def binarise_mask(mask, lam, in_shape, max_soft=0.0):
    """ Binarises a given low frequency image such that it has mean lambda.

    Elements are ranked from largest to smallest. The top-ranked ones become 1,
    the lowest-ranked ones become 0, and (when softening is active) a band
    around the cut-off is filled with a linear ramp from 1 down to 0.

    The input array is not modified: the result is built on a flattened copy.

    :param mask: Low frequency image, usually the result of `make_low_freq_image`
    :param lam: Mean value of final mask
    :param in_shape: Shape of inputs; its first two entries give the spatial
        size of the returned mask
    :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
    :return: Array of shape (1, 1, in_shape[0], in_shape[1])
    """
    # flatten() always returns a copy, so the writes below never reach the
    # caller's array (reshape(-1) would hand back a view for contiguous input).
    flat = mask.flatten()
    order = flat.argsort()[::-1]

    # lam * size is rarely an integer; round up or down at random.
    target = lam * flat.size
    num = math.ceil(target) if random.random() > 0.5 else math.floor(target)

    # The soft band may not reach past either end of the ranking, so the
    # softening is capped by both lam and 1 - lam.
    eff_soft = min(max_soft, lam, 1 - lam)

    soft = int(flat.size * eff_soft)
    num_low = int(num - soft)
    num_high = int(num + soft)

    flat[order[:num_high]] = 1
    flat[order[num_low:]] = 0
    flat[order[num_low:num_high]] = np.linspace(1, 0, (num_high - num_low))

    return flat.reshape((1, 1, in_shape[0], in_shape[1]))
|
|
|
|
|
def sample_mask(alpha, decay_power, shape, max_soft=0.0, reformulate=False):
    """ Sample a lambda and a matching mask

    A mean lambda is drawn with `sample_lam`, a low frequency image is created
    with `make_low_freq_image`, and that image is binarised with
    `binarise_mask` using the sampled lambda.

    :param alpha: Alpha value for beta distribution from which to sample mean of mask
    :param decay_power: Decay power for frequency decay prop 1/f**d
    :param shape: Shape of desired mask, list up to 3 dims; a bare int is
        treated as a one-element shape
    :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
    :param reformulate: If True, lambda is sampled with the reformulated beta distribution.
    :return: Tuple of (lambda as a Python float, mask array)
    """
    dims = (shape, ) if isinstance(shape, int) else shape

    # Mean of the mask
    lam = sample_lam(alpha, reformulate)

    # Smooth random image, then threshold it to the sampled mean
    low_freq = make_low_freq_image(decay_power, dims)
    binary = binarise_mask(low_freq, lam, dims, max_soft)

    return float(lam), binary
|
|
|
|
|
def sample_and_apply(x,
                     alpha,
                     decay_power,
                     shape,
                     max_soft=0.0,
                     reformulate=False):
    """ Sample a mask and use it to mix a batch with a shuffled copy of itself

    Each sample is combined with another sample of the same batch, chosen by
    a random permutation of the first axis: `x * mask + x[index] * (1 - mask)`.

    :param x: Image batch on which to apply fmix of shape [b, c, shape*]
    :param alpha: Alpha value for beta distribution from which to sample mean of mask
    :param decay_power: Decay power for frequency decay prop 1/f**d
    :param shape: Shape of desired mask, list up to 3 dims
    :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
    :param reformulate: If True, lambda is sampled with the reformulated beta distribution.
    :return: mixed input, permutation indices, lambda value of mix,
    """
    lam, mask = sample_mask(alpha, decay_power, shape, max_soft, reformulate)
    shuffled = np.random.permutation(x.shape[0])

    kept = x * mask
    swapped_in = x[shuffled] * (1 - mask)
    return kept + swapped_in, shuffled, lam
|
|
|
|
|
class FMixBase:
    """ Base class holding the configuration of the FMix augmentation

    Subclasses are expected to provide `__call__` and `loss`; both raise
    `NotImplementedError` here.

    Args:
        decay_power (float): Decay power for frequency decay prop 1/f**d
        alpha (float): Alpha value for beta distribution from which to sample mean of mask
        size ([int] | [int, int] | [int, int, int]): Shape of desired mask, list up to 3 dims
        max_soft (float): Softening value between 0 and 0.5 which smooths hard edges in the mask.
        reformulate (bool): If True, uses the reformulated lambda sampling.
    """

    def __init__(self,
                 decay_power=3,
                 alpha=1,
                 size=(32, 32),
                 max_soft=0.0,
                 reformulate=False):
        super().__init__()
        # Mask sampling configuration
        self.alpha = alpha
        self.decay_power = decay_power
        self.size = size
        self.max_soft = max_soft
        self.reformulate = reformulate
        # Placeholders, initialised to None
        self.index = None
        self.lam = None

    def __call__(self, x):
        # Not implemented in the base class
        raise NotImplementedError

    def loss(self, *args, **kwargs):
        # Not implemented in the base class
        raise NotImplementedError
|
|
|