You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

217 lines
7.1 KiB

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import random
import numpy as np
from scipy.stats import beta
def fftfreqnd(h, w=None, z=None):
    """ Compute the frequency magnitude of each bin of a discrete fourier transform

    The per-axis sample frequencies are combined by broadcasting into
    ``sqrt(fx**2 + fy**2 + fz**2)``. Along the second dimension only the
    leading ``w // 2 + 1`` bins (even ``w``) or ``w // 2 + 2`` bins (odd ``w``)
    are kept. With all three sizes given, broadcasting produces an array of
    shape ``(h, z, kept_w_bins)``.

    :param h: Required, first dimension size
    :param w: Optional, second dimension size
    :param z: Optional, third dimension size
    :return: Array of non-negative frequency magnitudes
    """
    freq_h = np.fft.fftfreq(h)
    freq_w = 0
    freq_z = 0

    if w is not None:
        freq_h = freq_h[..., np.newaxis]
        # Odd sizes keep one more leading bin than even sizes.
        kept = w // 2 + (2 if w % 2 == 1 else 1)
        freq_w = np.fft.fftfreq(w)[:kept]

    if z is not None:
        freq_h = freq_h[..., np.newaxis]
        # The third axis keeps every bin, stored as a column vector.
        freq_z = np.fft.fftfreq(z).reshape(-1, 1)

    return np.sqrt(freq_w * freq_w + freq_h * freq_h + freq_z * freq_z)
def get_spectrum(freqs, decay_power, ch, h, w=0, z=0):
    """ Draw gaussian noise in fourier space and damp it by 1/f**decay_power

    :param freqs: Bin values for the discrete fourier transform
    :param decay_power: Decay power for frequency decay prop 1/f**d
    :param ch: Number of channels for the resulting mask
    :param h: Required, first dimension size
    :param w: Optional, second dimension size
    :param z: Optional, third dimension size
    :return: Array of shape ``(ch, *freqs.shape, 2)``; every entry along the
        trailing axis of size 2 is scaled by the same per-bin factor
    """
    # Clamp frequencies from below so the zero-frequency bin cannot divide by zero.
    lowest = np.array([1. / max(w, h, z)])
    clipped = np.maximum(freqs, lowest)
    envelope = np.ones(1) / clipped**decay_power

    noise = np.random.randn(ch, *freqs.shape, 2)

    # Broadcast the envelope over the channel axis and the trailing axis.
    return envelope[np.newaxis, ..., np.newaxis] * noise
def make_low_freq_image(decay, shape, ch=1):
    """ Sample a low frequency image from fourier space

    :param decay: Decay power for frequency decay prop 1/f**d
    :param shape: Shape of desired mask, list up to 3 dims
    :param ch: Number of channels sampled in fourier space; only the first
        channel is kept in the returned image
    :return: Array of shape ``(1, *shape)`` shifted and scaled into [0, 1]
    """
    freqs = fftfreqnd(*shape)
    spectrum = get_spectrum(freqs, decay, ch, *shape)

    # get_spectrum stores the (real, imaginary) pair on the trailing axis of
    # size 2, so the complex spectrum must be assembled from that axis.
    # Indexing axis 1 instead would pick two rows of the first spatial axis
    # and fail outright when shape[0] == 1.
    spectrum = spectrum[..., 0] + 1j * spectrum[..., 1]
    mask = np.real(np.fft.irfftn(spectrum, shape))

    # Keep the first channel and crop every spatial axis to the requested size.
    mask = mask[(slice(0, 1), ) + tuple(slice(0, n) for n in shape)]

    mask = mask - mask.min()
    peak = mask.max()
    # A constant image has peak 0 after the shift; leave it as zeros rather
    # than producing NaNs through 0 / 0.
    if peak > 0:
        mask = mask / peak
    return mask
def sample_lam(alpha, reformulate=False):
    """ Draw a mixing coefficient lambda from a beta distribution

    :param alpha: Alpha value for beta distribution
    :param reformulate: If True, uses the reformulation of [1], i.e. draws
        from Beta(alpha + 1, alpha) instead of the symmetric Beta(alpha, alpha).
    :return: The sampled lambda
    """
    first_shape = alpha + 1 if reformulate else alpha
    return beta.rvs(first_shape, alpha)
def binarise_mask(mask, lam, in_shape, max_soft=0.0):
    """ Binarises a given low frequency image such that it has mean lambda.

    The largest ``lam * mask.size`` entries become 1 and the rest become 0.
    When ``lam * mask.size`` is not an integer it is rounded up or down at
    random. With a non-zero ``max_soft`` the entries ranked around the
    threshold are replaced by a linear ramp from 1 to 0 instead of a hard edge.

    :param mask: Low frequency image, usually the result of `make_low_freq_image`
    :param lam: Mean value of final mask
    :param in_shape: Shape of inputs (any number of spatial dims whose product equals ``mask.size``)
    :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
    :return: New array of shape ``(1, 1, *in_shape)``; the input ``mask`` is not modified
    """
    # Flatten a copy so the caller's array is never written through a view.
    flat = np.array(mask).reshape(-1)
    # Indices of the entries ordered from largest to smallest value.
    order = flat.argsort()[::-1]

    target = lam * flat.size
    rounder = math.ceil if random.random() > 0.5 else math.floor
    num = rounder(target)

    # The soft band may not extend past either end of the ranking.
    eff_soft = min(max_soft, lam, 1 - lam)
    soft = int(flat.size * eff_soft)
    num_low = int(num - soft)
    num_high = int(num + soft)

    flat[order[:num_high]] = 1
    flat[order[num_low:]] = 0
    flat[order[num_low:num_high]] = np.linspace(1, 0, num_high - num_low)

    # Leading batch and channel axes of size 1 let the mask broadcast over
    # inputs laid out as [b, c, *in_shape].
    return flat.reshape((1, 1, *in_shape))
def sample_mask(alpha, decay_power, shape, max_soft=0.0, reformulate=False):
    """ Draw a mean lambda from a beta distribution, sample a low frequency image and binarise it so
    that its mean matches this lambda

    :param alpha: Alpha value for beta distribution from which to sample mean of mask
    :param decay_power: Decay power for frequency decay prop 1/f**d
    :param shape: Shape of desired mask, list up to 3 dims; a bare int is treated as a 1-D shape
    :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
    :param reformulate: If True, uses the reformulation of [1].
    :return: Tuple of (lambda as a python float, binarised mask)
    """
    dims = (shape, ) if isinstance(shape, int) else shape

    # Lambda is drawn first, then the grey-scale image it will threshold.
    lam = sample_lam(alpha, reformulate)
    low_freq = make_low_freq_image(decay_power, dims)

    return float(lam), binarise_mask(low_freq, lam, dims, max_soft)
def sample_and_apply(x,
                     alpha,
                     decay_power,
                     shape,
                     max_soft=0.0,
                     reformulate=False):
    """ Sample an FMix mask and blend a batch with a shuffled copy of itself

    Each sample keeps its own content where the mask is 1 and takes the
    content of its randomly assigned partner where the mask is 0.

    :param x: Image batch on which to apply fmix of shape [b, c, shape*]
    :param alpha: Alpha value for beta distribution from which to sample mean of mask
    :param decay_power: Decay power for frequency decay prop 1/f**d
    :param shape: Shape of desired mask, list up to 3 dims
    :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
    :param reformulate: If True, uses the reformulation of [1].
    :return: mixed input, permutation indices, lambda value of mix,
    """
    lam, mask = sample_mask(alpha, decay_power, shape, max_soft, reformulate)

    # Partner assignment: sample i is mixed with sample perm[i].
    perm = np.random.permutation(x.shape[0])
    mixed = x * mask + x[perm] * (1 - mask)
    return mixed, perm, lam
class FMixBase:
    """ Base class holding the configuration shared by FMix implementations

    Subclasses are expected to provide ``__call__`` and ``loss``; both raise
    ``NotImplementedError`` here.

    Args:
        decay_power (float): Decay power for frequency decay prop 1/f**d
        alpha (float): Alpha value for beta distribution from which to sample mean of mask
        size ([int] | [int, int] | [int, int, int]): Shape of desired mask, list up to 3 dims
        max_soft (float): Softening value between 0 and 0.5 which smooths hard edges in the mask.
        reformulate (bool): If True, uses the reformulation of [1].
    """

    def __init__(self,
                 decay_power=3,
                 alpha=1,
                 size=(32, 32),
                 max_soft=0.0,
                 reformulate=False):
        super().__init__()
        # Mask sampling configuration.
        self.alpha = alpha
        self.decay_power = decay_power
        self.size = size
        self.max_soft = max_soft
        self.reformulate = reformulate
        # Per-call state, initialised empty: permutation indices and lambda.
        self.lam = None
        self.index = None

    def __call__(self, x):
        """ Apply FMix to a batch; not implemented in the base class """
        raise NotImplementedError

    def loss(self, *args, **kwargs):
        """ Compute the FMix loss; not implemented in the base class """
        raise NotImplementedError