[initial commit]

branch: main
author: tiankeyu, 2 years ago
commit 6fda3bf645
  1. .gitignore            (+7)
  2. LICENSE               (+395)
  3. decoder.py            (+112)
  4. dist.py               (+142)
  5. encoder.py            (+207)
  6. launch.py             (+87)
  7. main.py               (+156)
  8. models/__init__.py    (+83)
  9. models/convnext.py    (+212)
 10. models/resnet.py      (+105)
 11. sampler.py            (+167)
 12. scripts/pt.sh         (+49)
 13. spark.py              (+313)
 14. utils/imagenet.py     (+172)
 15. utils/lr_control.py   (+73)
 16. utils/meta.py         (+163)
 17. utils/misc.py         (+273)
 18. utils/optim.py        (+160)

.gitignore
@@ -0,0 +1,7 @@
*.swp
**/__pycache__/**
.idea/*
ckpt/
*.pth
*.log
*.txt

LICENSE
@@ -0,0 +1,395 @@
Attribution 4.0 International
=======================================================================
Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.
Using Creative Commons Public Licenses
Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.
Considerations for licensors: Our public licenses are
intended for use by those authorized to give the public
permission to use material in ways otherwise restricted by
copyright and certain other rights. Our licenses are
irrevocable. Licensors should read and understand the terms
and conditions of the license they choose before applying it.
Licensors should also secure all rights necessary before
applying our licenses so that the public can reuse the
material as expected. Licensors should clearly mark any
material not subject to the license. This includes other CC-
licensed material, or material used under an exception or
limitation to copyright. More considerations for licensors:
wiki.creativecommons.org/Considerations_for_licensors
Considerations for the public: By using one of our public
licenses, a licensor grants the public permission to use the
licensed material under specified terms and conditions. If
the licensor's permission is not necessary for any reason--for
example, because of any applicable exception or limitation to
copyright--then that use is not regulated by the license. Our
licenses grant only permissions under copyright and certain
other rights that a licensor has authority to grant. Use of
the licensed material may still be restricted for other
reasons, including because others have copyright or other
rights in the material. A licensor may make special requests,
such as asking that all changes be marked or described.
Although not required by our licenses, you are encouraged to
respect those requests where reasonable. More considerations
for the public:
wiki.creativecommons.org/Considerations_for_licensees
=======================================================================
Creative Commons Attribution 4.0 International Public License
By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution 4.0 International Public License ("Public License"). To the
extent this Public License may be interpreted as a contract, You are
granted the Licensed Rights in consideration of Your acceptance of
these terms and conditions, and the Licensor grants You such rights in
consideration of benefits the Licensor receives from making the
Licensed Material available under these terms and conditions.
Section 1 -- Definitions.
a. Adapted Material means material subject to Copyright and Similar
Rights that is derived from or based upon the Licensed Material
and in which the Licensed Material is translated, altered,
arranged, transformed, or otherwise modified in a manner requiring
permission under the Copyright and Similar Rights held by the
Licensor. For purposes of this Public License, where the Licensed
Material is a musical work, performance, or sound recording,
Adapted Material is always produced where the Licensed Material is
synched in timed relation with a moving image.
b. Adapter's License means the license You apply to Your Copyright
and Similar Rights in Your contributions to Adapted Material in
accordance with the terms and conditions of this Public License.
c. Copyright and Similar Rights means copyright and/or similar rights
closely related to copyright including, without limitation,
performance, broadcast, sound recording, and Sui Generis Database
Rights, without regard to how the rights are labeled or
categorized. For purposes of this Public License, the rights
specified in Section 2(b)(1)-(2) are not Copyright and Similar
Rights.
d. Effective Technological Measures means those measures that, in the
absence of proper authority, may not be circumvented under laws
fulfilling obligations under Article 11 of the WIPO Copyright
Treaty adopted on December 20, 1996, and/or similar international
agreements.
e. Exceptions and Limitations means fair use, fair dealing, and/or
any other exception or limitation to Copyright and Similar Rights
that applies to Your use of the Licensed Material.
f. Licensed Material means the artistic or literary work, database,
or other material to which the Licensor applied this Public
License.
g. Licensed Rights means the rights granted to You subject to the
terms and conditions of this Public License, which are limited to
all Copyright and Similar Rights that apply to Your use of the
Licensed Material and that the Licensor has authority to license.
h. Licensor means the individual(s) or entity(ies) granting rights
under this Public License.
i. Share means to provide material to the public by any means or
process that requires permission under the Licensed Rights, such
as reproduction, public display, public performance, distribution,
dissemination, communication, or importation, and to make material
available to the public including in ways that members of the
public may access the material from a place and at a time
individually chosen by them.
j. Sui Generis Database Rights means rights other than copyright
resulting from Directive 96/9/EC of the European Parliament and of
the Council of 11 March 1996 on the legal protection of databases,
as amended and/or succeeded, as well as other essentially
equivalent rights anywhere in the world.
k. You means the individual or entity exercising the Licensed Rights
under this Public License. Your has a corresponding meaning.
Section 2 -- Scope.
a. License grant.
1. Subject to the terms and conditions of this Public License,
the Licensor hereby grants You a worldwide, royalty-free,
non-sublicensable, non-exclusive, irrevocable license to
exercise the Licensed Rights in the Licensed Material to:
a. reproduce and Share the Licensed Material, in whole or
in part; and
b. produce, reproduce, and Share Adapted Material.
2. Exceptions and Limitations. For the avoidance of doubt, where
Exceptions and Limitations apply to Your use, this Public
License does not apply, and You do not need to comply with
its terms and conditions.
3. Term. The term of this Public License is specified in Section
6(a).
4. Media and formats; technical modifications allowed. The
Licensor authorizes You to exercise the Licensed Rights in
all media and formats whether now known or hereafter created,
and to make technical modifications necessary to do so. The
Licensor waives and/or agrees not to assert any right or
authority to forbid You from making technical modifications
necessary to exercise the Licensed Rights, including
technical modifications necessary to circumvent Effective
Technological Measures. For purposes of this Public License,
simply making modifications authorized by this Section 2(a)
(4) never produces Adapted Material.
5. Downstream recipients.
a. Offer from the Licensor -- Licensed Material. Every
recipient of the Licensed Material automatically
receives an offer from the Licensor to exercise the
Licensed Rights under the terms and conditions of this
Public License.
b. No downstream restrictions. You may not offer or impose
any additional or different terms or conditions on, or
apply any Effective Technological Measures to, the
Licensed Material if doing so restricts exercise of the
Licensed Rights by any recipient of the Licensed
Material.
6. No endorsement. Nothing in this Public License constitutes or
may be construed as permission to assert or imply that You
are, or that Your use of the Licensed Material is, connected
with, or sponsored, endorsed, or granted official status by,
the Licensor or others designated to receive attribution as
provided in Section 3(a)(1)(A)(i).
b. Other rights.
1. Moral rights, such as the right of integrity, are not
licensed under this Public License, nor are publicity,
privacy, and/or other similar personality rights; however, to
the extent possible, the Licensor waives and/or agrees not to
assert any such rights held by the Licensor to the limited
extent necessary to allow You to exercise the Licensed
Rights, but not otherwise.
2. Patent and trademark rights are not licensed under this
Public License.
3. To the extent possible, the Licensor waives any right to
collect royalties from You for the exercise of the Licensed
Rights, whether directly or through a collecting society
under any voluntary or waivable statutory or compulsory
licensing scheme. In all other cases the Licensor expressly
reserves any right to collect such royalties.
Section 3 -- License Conditions.
Your exercise of the Licensed Rights is expressly made subject to the
following conditions.
a. Attribution.
1. If You Share the Licensed Material (including in modified
form), You must:
a. retain the following if it is supplied by the Licensor
with the Licensed Material:
i. identification of the creator(s) of the Licensed
Material and any others designated to receive
attribution, in any reasonable manner requested by
the Licensor (including by pseudonym if
designated);
ii. a copyright notice;
iii. a notice that refers to this Public License;
iv. a notice that refers to the disclaimer of
warranties;
v. a URI or hyperlink to the Licensed Material to the
extent reasonably practicable;
b. indicate if You modified the Licensed Material and
retain an indication of any previous modifications; and
c. indicate the Licensed Material is licensed under this
Public License, and include the text of, or the URI or
hyperlink to, this Public License.
2. You may satisfy the conditions in Section 3(a)(1) in any
reasonable manner based on the medium, means, and context in
which You Share the Licensed Material. For example, it may be
reasonable to satisfy the conditions by providing a URI or
hyperlink to a resource that includes the required
information.
3. If requested by the Licensor, You must remove any of the
information required by Section 3(a)(1)(A) to the extent
reasonably practicable.
4. If You Share Adapted Material You produce, the Adapter's
License You apply must not prevent recipients of the Adapted
Material from complying with this Public License.
Section 4 -- Sui Generis Database Rights.
Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
to extract, reuse, reproduce, and Share all or a substantial
portion of the contents of the database;
b. if You include all or a substantial portion of the database
contents in a database in which You have Sui Generis Database
Rights, then the database in which You have Sui Generis Database
Rights (but not its individual contents) is Adapted Material; and
c. You must comply with the conditions in Section 3(a) if You Share
all or a substantial portion of the contents of the database.
For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
c. The disclaimer of warranties and limitation of liability provided
above shall be interpreted in a manner that, to the extent
possible, most closely approximates an absolute disclaimer and
waiver of all liability.
Section 6 -- Term and Termination.
a. This Public License applies for the term of the Copyright and
Similar Rights licensed here. However, if You fail to comply with
this Public License, then Your rights under this Public License
terminate automatically.
b. Where Your right to use the Licensed Material has terminated under
Section 6(a), it reinstates:
1. automatically as of the date the violation is cured, provided
it is cured within 30 days of Your discovery of the
violation; or
2. upon express reinstatement by the Licensor.
For the avoidance of doubt, this Section 6(b) does not affect any
right the Licensor may have to seek remedies for Your violations
of this Public License.
c. For the avoidance of doubt, the Licensor may also offer the
Licensed Material under separate terms or conditions or stop
distributing the Licensed Material at any time; however, doing so
will not terminate this Public License.
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
License.
Section 7 -- Other Terms and Conditions.
a. The Licensor shall not be bound by any additional or different
terms or conditions communicated by You unless expressly agreed.
b. Any arrangements, understandings, or agreements regarding the
Licensed Material not stated herein are separate from and
independent of the terms and conditions of this Public License.
Section 8 -- Interpretation.
a. For the avoidance of doubt, this Public License does not, and
shall not be interpreted to, reduce, limit, restrict, or impose
conditions on any use of the Licensed Material that could lawfully
be made without permission under this Public License.
b. To the extent possible, if any provision of this Public License is
deemed unenforceable, it shall be automatically reformed to the
minimum extent necessary to make it enforceable. If the provision
cannot be reformed, it shall be severed from this Public License
without affecting the enforceability of the remaining terms and
conditions.
c. No term or condition of this Public License will be waived and no
failure to comply consented to unless expressly agreed to by the
Licensor.
d. Nothing in this Public License constitutes or may be interpreted
as a limitation upon, or waiver of, any privileges and immunities
that apply to the Licensor or You, including from the legal
processes of any jurisdiction or authority.
=======================================================================
Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the “Licensor.” The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.
Creative Commons may be contacted at creativecommons.org.

decoder.py
@@ -0,0 +1,112 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from timm.models.layers import trunc_normal_, DropPath, Mlp
import torch.nn as nn
from utils.misc import is_pow2n
_BN = None
class UNetBlock2x(nn.Module):
    def __init__(self, cin, cout, cmid, last_act=True):
        super().__init__()
        if cmid == 0:
            c_mid = cin
        elif cmid == 1:
            c_mid = (cin + cout) // 2
        else:  # guard against a silently undefined c_mid for other values
            raise NotImplementedError(f'unsupported cmid: {cmid}')
        self.b = nn.Sequential(
            nn.Conv2d(cin, c_mid, 3, 1, 1, bias=False), _BN(c_mid), nn.ReLU6(inplace=True),
            nn.Conv2d(c_mid, cout, 3, 1, 1, bias=False), _BN(cout), (nn.ReLU6(inplace=True) if last_act else nn.Identity()),
        )

    def forward(self, x):
        return self.b(x)


class DecoderConv(nn.Module):
    def __init__(self, cin, cout, double, heavy, cmid):
        super().__init__()
        self.up = nn.ConvTranspose2d(cin, cin, kernel_size=4 if double else 2, stride=2, padding=1 if double else 0, bias=True)
        ls = [UNetBlock2x(cin, (cin if i != heavy[1]-1 else cout), cmid=cmid, last_act=i != heavy[1]-1) for i in range(heavy[1])]
        self.conv = nn.Sequential(*ls)

    def forward(self, x):
        x = self.up(x)
        return self.conv(x)


class LightDecoder(nn.Module):
    def __init__(self, decoder_fea_dim, upsample_ratio, double=False, heavy=None, cmid=0, sbn=False):
        global _BN
        _BN = nn.SyncBatchNorm if sbn else nn.BatchNorm2d
        super().__init__()
        self.fea_dim = decoder_fea_dim
        if heavy is None:
            heavy = [0, 1]
        heavy[1] = max(1, heavy[1])
        self.double_bool = double
        self.heavy = heavy
        self.cmid = cmid
        self.sbn = sbn
        assert is_pow2n(upsample_ratio)
        n = round(math.log2(upsample_ratio))
        channels = [self.fea_dim // 2**i for i in range(n+1)]
        self.dec = nn.ModuleList([
            DecoderConv(cin, cout, double, heavy, cmid) for (cin, cout) in zip(channels[:-1], channels[1:])
        ])
        self.proj = nn.Conv2d(channels[-1], 3, kernel_size=1, stride=1, bias=True)
        self.initialize()

    def forward(self, to_dec):
        x = 0
        for i, d in enumerate(self.dec):
            if i < len(to_dec) and to_dec[i] is not None:
                x = x + to_dec[i]
            x = self.dec[i](x)
        return self.proj(x)

    def num_para(self):
        tot = sum(p.numel() for p in self.parameters())
        para1 = para2 = 0
        for m in self.dec.modules():
            if isinstance(m, nn.ConvTranspose2d):
                para1 += sum(p.numel() for p in m.parameters())
            elif isinstance(m, nn.Conv2d):
                para2 += sum(p.numel() for p in m.parameters())
        return f'#para: {tot/1e6:.2f} (dconv={para1/1e6:.2f}, conv={para2/1e6:.2f}, ot={(tot-para1-para2)/1e6:.2f})'

    def extra_repr(self) -> str:
        return f'fea_dim={self.fea_dim}, dbl={self.double_bool}, heavy={self.heavy}, cmid={self.cmid}, sbn={self.sbn}'

    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Embedding):
                trunc_normal_(m.weight, std=.02)
                if m.padding_idx is not None:
                    m.weight.data[m.padding_idx].zero_()
            elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.SyncBatchNorm)):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
            elif isinstance(m, nn.Conv2d):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose2d):  # Conv2d is handled above, so only ConvTranspose2d can reach this branch
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.)
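
A minimal smoke test for LightDecoder (hypothetical values, not part of the commit): with decoder_fea_dim=512 and upsample_ratio=32, five 2x upsampling blocks take a 7x7 bottleneck feature back to a 224x224 RGB image.

if __name__ == '__main__':
    import torch
    dec = LightDecoder(decoder_fea_dim=512, upsample_ratio=32, sbn=False)
    # feed only the deepest pyramid slot; the four shallower slots stay None
    to_dec = [torch.rand(2, 512, 7, 7), None, None, None, None]
    print(dec(to_dec).shape)  # torch.Size([2, 3, 224, 224])
    print(dec.num_para())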

dist.py
@@ -0,0 +1,142 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import functools
import os
from typing import List
from typing import Union
import torch
import torch.distributed as tdist
import torch.multiprocessing as mp
__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cpu'
__initialized = False
def initialized():
    return __initialized


def initialize(backend='nccl'):
    # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    global_rank, num_gpus = int(os.environ.get('RANK', 'error')), torch.cuda.device_count()  # fails fast with a ValueError if RANK is unset, i.e. not spawned by a launcher
    local_rank = global_rank % num_gpus
    torch.cuda.set_device(local_rank)
    tdist.init_process_group(backend=backend)  # do not pass init_method='env://' here; it is already the default

    global __rank, __local_rank, __world_size, __device, __initialized
    __local_rank = local_rank
    __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
    __device = torch.empty(1).cuda().device
    __initialized = True

    assert tdist.is_initialized(), 'torch.distributed is not initialized!'


def get_rank():
    return __rank


def get_local_rank():
    return __local_rank


def get_world_size():
    return __world_size


def get_device():
    return __device


def is_master():
    return __rank == 0


def is_local_master():
    return __local_rank == 0


def parallelize(net, syncbn=False):
    if syncbn:
        net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = net.cuda()
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
    return net


def new_group(ranks: List[int]):
    return tdist.new_group(ranks=ranks)


def barrier():
    tdist.barrier()


def allreduce(t: torch.Tensor) -> None:
    if not t.is_cuda:
        cu = t.detach().cuda()
        tdist.all_reduce(cu)
        t.copy_(cu.cpu())
    else:
        tdist.all_reduce(t)


def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    if not t.is_cuda:
        t = t.cuda()
    ls = [torch.empty_like(t) for _ in range(__world_size)]
    tdist.all_gather(ls, t)
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls


def broadcast(t: torch.Tensor, src_rank) -> None:
    if not t.is_cuda:
        cu = t.detach().cuda()
        tdist.broadcast(cu, src=src_rank)
        t.copy_(cu.cpu())
    else:
        tdist.broadcast(t, src=src_rank)


def dist_fmt_vals(val, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:
    ts = torch.zeros(__world_size)
    ts[__rank] = val
    allreduce(ts)  # after the sum-reduce, slot i holds rank i's value
    if fmt is None:
        return ts
    return [fmt % v for v in ts.cpu().numpy().tolist()]


def master_only(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        force = kwargs.pop('force', False)
        if force or is_master():
            ret = func(*args, **kwargs)
        else:
            ret = None
        barrier()
        return ret
    return wrapper


def local_master_only(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        force = kwargs.pop('force', False)
        if force or is_local_master():
            ret = func(*args, **kwargs)
        else:
            ret = None
        barrier()
        return ret
    return wrapper
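
A usage sketch (hypothetical, not part of the commit); it assumes the process was spawned by a distributed launcher, so that the RANK environment variable is set and a GPU is visible:

if __name__ == '__main__':
    initialize(backend='nccl')  # reads RANK; fails fast if not started by a launcher

    @master_only
    def log(msg):
        print(msg)  # body runs on rank 0 only; every rank then meets at barrier()

    log(f'world size = {get_world_size()}')
    # every rank prints the same formatted list, e.g. ['0.00', '1.00', ...]
    print(dist_fmt_vals(float(get_rank()), fmt='%.2f'))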

encoder.py
@@ -0,0 +1,207 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from timm.models.layers import DropPath
_cur_active: torch.Tensor = None  # B1ff: the (B, 1, f, f) binary mask of currently active (unmasked) patches


def _get_active_ex_or_ii(H, returning_active_ex=True):
    downsample_raito = H // _cur_active.shape[-1]
    active_ex = _cur_active.repeat_interleave(downsample_raito, 2).repeat_interleave(downsample_raito, 3)
    return active_ex if returning_active_ex else active_ex.squeeze(1).nonzero(as_tuple=True)  # ii: bi, hi, wi


def sp_conv_forward(self, x: torch.Tensor):
    x = super(type(self), self).forward(x)
    x *= _get_active_ex_or_ii(H=x.shape[2], returning_active_ex=True)  # (BCHW) *= (B1HW), zero out inactive positions
    return x


def sp_bn_forward(self, x: torch.Tensor):
    ii = _get_active_ex_or_ii(H=x.shape[2], returning_active_ex=False)
    bhwc = x.permute(0, 2, 3, 1)
    nc = bhwc[ii]  # gather the features of active positions only, shape (N_active, C)
    nc = super(type(self), self).forward(nc)  # BN1d forward over the active positions
    bchw = torch.zeros_like(bhwc)
    bchw[ii] = nc
    bchw = bchw.permute(0, 3, 1, 2)
    return bchw


class SparseConv2d(nn.Conv2d):
    forward = sp_conv_forward


class SparseMaxPooling(nn.MaxPool2d):
    forward = sp_conv_forward


class SparseAvgPooling(nn.AvgPool2d):
    forward = sp_conv_forward


class SparseBatchNorm2d(nn.BatchNorm1d):
    forward = sp_bn_forward


class SparseSyncBatchNorm2d(nn.SyncBatchNorm):
    forward = sp_bn_forward


class SparseConvNeXtLayerNorm(nn.LayerNorm):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", sparse=True):
        if data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        super().__init__(normalized_shape, eps, elementwise_affine=True)
        self.data_format = data_format
        self.sparse = sparse

    def forward(self, x):
        if x.ndim == 4:  # BHWC or BCHW
            if self.data_format == "channels_last":  # BHWC
                if self.sparse:
                    ii = _get_active_ex_or_ii(H=x.shape[1], returning_active_ex=False)
                    nc = x[ii]
                    nc = super(SparseConvNeXtLayerNorm, self).forward(nc)
                    x = torch.zeros_like(x)
                    x[ii] = nc
                    return x
                else:
                    return super(SparseConvNeXtLayerNorm, self).forward(x)
            else:  # channels_first, BCHW
                if self.sparse:
                    ii = _get_active_ex_or_ii(H=x.shape[2], returning_active_ex=False)
                    bhwc = x.permute(0, 2, 3, 1)
                    nc = bhwc[ii]
                    nc = super(SparseConvNeXtLayerNorm, self).forward(nc)
                    x = torch.zeros_like(bhwc)
                    x[ii] = nc
                    return x.permute(0, 3, 1, 2)
                else:
                    u = x.mean(1, keepdim=True)
                    s = (x - u).pow(2).mean(1, keepdim=True)
                    x = (x - u) / torch.sqrt(s + self.eps)
                    x = self.weight[:, None, None] * x + self.bias[:, None, None]
                    return x
        else:  # BLC or BC
            if self.sparse:
                raise NotImplementedError
            else:
                return super(SparseConvNeXtLayerNorm, self).forward(x)

    def __repr__(self):
        return super(SparseConvNeXtLayerNorm, self).__repr__()[:-1] + f', ch={self.data_format.split("_")[-1]}, sp={self.sparse})'


class SparseConvNeXtBlock(nn.Module):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """

    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, sparse=True, ks=7):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=ks, padding=ks//2, groups=dim)  # depthwise conv
        self.norm = SparseConvNeXtLayerNorm(dim, eps=1e-6, sparse=sparse)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.sparse = sparse

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        if self.sparse:
            x *= _get_active_ex_or_ii(H=x.shape[2], returning_active_ex=True)
        x = input + self.drop_path(x)
        return x

    def __repr__(self):
        return super(SparseConvNeXtBlock, self).__repr__()[:-1] + f', sp={self.sparse})'


class SparseEncoder(nn.Module):
    def __init__(self, cnn, input_size, downsample_raito, encoder_fea_dim, verbose=False, sbn=False):
        super(SparseEncoder, self).__init__()
        self.sp_cnn = SparseEncoder.dense_model_to_sparse(m=cnn, verbose=verbose, sbn=sbn)
        self.input_size, self.downsample_raito, self.fea_dim = input_size, downsample_raito, encoder_fea_dim

    @staticmethod
    def dense_model_to_sparse(m: nn.Module, verbose=False, sbn=False):
        oup = m
        if isinstance(m, nn.Conv2d):
            m: nn.Conv2d
            bias = m.bias is not None
            oup = SparseConv2d(
                m.in_channels, m.out_channels,
                kernel_size=m.kernel_size, stride=m.stride, padding=m.padding,
                dilation=m.dilation, groups=m.groups, bias=bias, padding_mode=m.padding_mode,
            )
            oup.weight.data.copy_(m.weight.data)
            if bias:
                oup.bias.data.copy_(m.bias.data)
        elif isinstance(m, nn.MaxPool2d):
            m: nn.MaxPool2d
            oup = SparseMaxPooling(m.kernel_size, stride=m.stride, padding=m.padding, dilation=m.dilation, return_indices=m.return_indices, ceil_mode=m.ceil_mode)
        elif isinstance(m, nn.AvgPool2d):
            m: nn.AvgPool2d
            oup = SparseAvgPooling(m.kernel_size, m.stride, m.padding, ceil_mode=m.ceil_mode, count_include_pad=m.count_include_pad, divisor_override=m.divisor_override)
        elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            m: nn.BatchNorm2d
            oup = (SparseSyncBatchNorm2d if sbn else SparseBatchNorm2d)(m.weight.shape[0], eps=m.eps, momentum=m.momentum, affine=m.affine, track_running_stats=m.track_running_stats)
            oup.weight.data.copy_(m.weight.data)
            oup.bias.data.copy_(m.bias.data)
            oup.running_mean.data.copy_(m.running_mean.data)
            oup.running_var.data.copy_(m.running_var.data)
            oup.num_batches_tracked.data.copy_(m.num_batches_tracked.data)
            if hasattr(m, "qconfig"):
                oup.qconfig = m.qconfig
        elif isinstance(m, nn.LayerNorm) and not isinstance(m, SparseConvNeXtLayerNorm):
            m: nn.LayerNorm
            oup = SparseConvNeXtLayerNorm(m.weight.shape[0], eps=m.eps)
            oup.weight.data.copy_(m.weight.data)
            oup.bias.data.copy_(m.bias.data)
        elif isinstance(m, (nn.Conv1d,)):
            raise NotImplementedError

        for name, child in m.named_children():
            oup.add_module(name, SparseEncoder.dense_model_to_sparse(child, verbose=verbose, sbn=sbn))
        del m
        return oup

    def forward(self, x, pyramid):
        return self.sp_cnn(x, pyramid=pyramid)
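
A smoke test sketch (hypothetical, not part of the commit): convert a tiny dense CNN to its sparse counterpart and run it on a half-masked input. _cur_active is the module-level mask that every sparse op reads; the training code (spark.py) is expected to set it before each forward pass.

if __name__ == '__main__':
    dense = nn.Sequential(nn.Conv2d(3, 8, 3, 2, 1, bias=False), nn.BatchNorm2d(8), nn.ReLU(inplace=True))
    sp = SparseEncoder.dense_model_to_sparse(dense, sbn=False)
    _cur_active = (torch.rand(2, 1, 4, 4) > 0.5).float()  # B1ff mask: 1 = kept patch, 0 = masked
    out = sp(torch.rand(2, 3, 32, 32))
    print(out.shape)  # torch.Size([2, 8, 16, 16]); masked positions stay exactly zero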

launch.py
@@ -0,0 +1,87 @@
#!/usr/bin/python3
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import functools
import os
import socket
import subprocess
import sys
from typing import List
echo = lambda info: os_system(f'echo "[$(date "+%m-%d-%H:%M:%S")] ({os.path.basename(sys._getframe().f_back.f_code.co_filename)}, line{sys._getframe().f_back.f_lineno})=> {info}"')
os_system = functools.partial(subprocess.call, shell=True)
os_system_get_stdout = lambda cmd: subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8')
def os_system_get_stdout_stderr(cmd):
    sp = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    return sp.stdout.decode('utf-8'), sp.stderr.decode('utf-8')


def __find_free_port():
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 will cause the OS to find an available port for us
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    # NOTE: there is still a chance the port could be taken by other processes.
    return port


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch Distributed Launcher')
    parser.add_argument('--main_py_relpath', type=str, default='main.py',
                        help='specify the script to launch.')
    # distributed environment
    parser.add_argument('--num_nodes', type=int, default=1)
    parser.add_argument('--ngpu_per_node', type=int, default=1)
    parser.add_argument('--node_rank', type=int, default=0,
                        help='node rank, ranging from 0 to [dist_num_nodes]-1')
    parser.add_argument('--master_address', type=str, default='128.0.0.0',
                        help='master address for distributed communication')
    parser.add_argument('--master_port', type=int, default=30001,
                        help='master port for distributed communication')
    # other args
    known_args, other_args = parser.parse_known_args()
    other_args: List[str]
    echo(f'[other_args received by launch.py]: {other_args}')

    # normalize the trailing argument string: strip spaces around '=' and
    # turn bare boolean flags like `--sbn` into `--sbn=1`
    main_args = other_args[-1]
    main_args = '='.join(map(str.strip, main_args.split('=')))
    main_args = main_args.split(' ')
    for i, a in enumerate(main_args):
        if len(a) and '=' not in a:
            main_args[i] = f'{a}=1'
    other_args[-1] = ' '.join(main_args)
    echo(f'[final other_args]: {other_args[-1]}')

    if known_args.num_nodes > 1:
        os.environ['NPROC_PER_NODE'] = str(known_args.ngpu_per_node)
        cmd = (
            f'python3 -m torch.distributed.launch'
            f' --nproc_per_node={known_args.ngpu_per_node}'
            f' --nnodes={known_args.num_nodes}'
            f' --node_rank={known_args.node_rank}'
            f' --master_addr={known_args.master_address}'
            f' --master_port={known_args.master_port}'
            f' {known_args.main_py_relpath}'
            f' {" ".join(other_args)}'
        )
    else:
        cmd = (
            f'python3 -m torch.distributed.launch'
            f' --nproc_per_node={known_args.ngpu_per_node}'
            f' --master_port={known_args.master_port}'
            f' {known_args.main_py_relpath}'
            f' {" ".join(other_args)}'
        )
    exit_code = subprocess.call(cmd, shell=True)
    sys.exit(exit_code)
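
A standalone restatement of the flag-normalization step above, with a worked example (a hypothetical helper, not part of the commit; it mirrors the loop in launch.py's __main__):

def normalize_flags(arg_str: str) -> str:
    # strip spaces around '=' and turn bare flags like '--sbn' into '--sbn=1'
    toks = '='.join(map(str.strip, arg_str.split('='))).split(' ')
    return ' '.join(f'{t}=1' if t and '=' not in t else t for t in toks)

assert normalize_flags('--exp_name=demo --sbn --ep=400') == '--exp_name=demo --sbn=1 --ep=400'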

main.py
@@ -0,0 +1,156 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import sys
import time
from functools import partial
from typing import List
import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
import dist
import encoder
from decoder import LightDecoder
from models import build_sparse_encoder
from sampler import DistInfiniteBatchSampler, worker_init_fn
from spark import SparK
from utils import meta, misc, optim
from utils.imagenet import build_imagenet
from utils.lr_control import lr_wd_annealing, get_param_groups
def main_pt():
    args: meta.Args = meta.init_dist_and_get_args()
    print(f'global bs={args.glb_batch_size}, local bs={args.batch_size}')
    print(f'initial args:\n{str(args)}')
    args.log_epoch()

    # build data
    print(f'[build data for pre-training] ...\n')
    dataset_train, _ = build_imagenet('pt', args.data_path, args.data_set, args.input_size, eval_crop_pct=None, rrc=args.rrc)
    data_loader_train = DataLoader(
        dataset=dataset_train, num_workers=args.num_workers, pin_memory=True,
        batch_sampler=DistInfiniteBatchSampler(
            dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size, seed=args.seed,
            shuffle=True, filling=True, rank=dist.get_rank(), world_size=dist.get_world_size(),
        ), worker_init_fn=worker_init_fn
    )
    itrt_train = iter(data_loader_train)
    iters_train = len(data_loader_train)
    print(f'[dataloader] gbs={args.glb_batch_size}, lbs={args.batch_size}, iters_train={iters_train}')

    # build models (encoder, decoder, and other components)
    enc: encoder.SparseEncoder = build_sparse_encoder(args.model, input_size=args.input_size, sbn=args.sbn, drop_path_rate=args.dp, verbose=False)
    dec = LightDecoder(args.dec_dim, enc.downsample_raito, double=args.double, heavy=args.hea, cmid=args.cmid, sbn=args.sbn)
    spark = SparK(
        sparse_encoder=enc, dense_decoder=dec, mask_ratio=args.mask, mask_ratio2=args.mask2, uniform=args.uni,
        using_pe=args.pe, pix_norm=args.pn, dense_loss=args.den, loss_l2=args.loss_l2,
        en_de_norm=args.en_de_norm, en_de_lin=args.en_de_lin, sbn=args.sbn, pyramid=args.py,
    )
    print(f'[PT model] model = {spark}\n')
    spark.to(args.device)
    model: DistributedDataParallel = DistributedDataParallel(spark, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
    model_without_ddp: SparK = model.module

    # build optimizer and lr_scheduler
    param_groups: List[dict] = get_param_groups(model_without_ddp, nowd_keys={'pos_embed', 'mask_token', 'gamma'}, lr_scale=0)
    opt_clz = {
        'sgd': partial(torch.optim.SGD, momentum=0.9, nesterov=True),
        'adamw': partial(torch.optim.AdamW, betas=(0.9, args.ada)),
        'lamb': partial(optim.TimmLAMB, betas=(0.9, args.ada), max_grad_norm=args.clip),
    }[args.opt]
    optimizer = opt_clz(params=param_groups, lr=args.lr, weight_decay=0.0)
    print(f'[optimizer] optimizer({opt_clz}) = {optimizer}\n')

    # try to resume
    next_ep, performance_desc = misc.load_checkpoint(args.resume, model_without_ddp, optimizer) if len(args.resume) else (0, '[no performance_desc]')
    if next_ep >= args.ep:
        # load from a complete checkpoint file
        print(f' [*] [PT already done] Min/Last Recon Loss: {performance_desc}')
    else:
        # perform pre-training
        start_time = time.time()
        min_loss = 1e9
        print(f'[PT start] from ep{next_ep}')
        for ep in range(next_ep, args.ep):
            if hasattr(itrt_train, 'set_epoch'):
                itrt_train.set_epoch(ep)
            stats, (sec, remain_time, finish_time) = pre_train_one_ep(ep, args, itrt_train, iters_train, model, optimizer)
            last_loss = stats['last_loss']
            min_loss = min(min_loss, last_loss)
            performance_desc = f'{min_loss:.4f} {last_loss:.4f}'
            print(f' [*] [ep{ep}] Min/Last Recon Loss: {performance_desc}, Remain: {remain_time}, Finish: {finish_time}')

            args.cur_phase = 'PT'
            args.cur_ep = f'{ep+1}/{args.ep}'
            args.remain_time, args.finish_time = str(remain_time), str(finish_time)
            args.last_loss = last_loss
            args.log_epoch()
            misc.save_checkpoint('ckpt-last.pth', args, ep, performance_desc, model_without_ddp.state_dict(with_config=True), optimizer.state_dict())

        # finish pre-training
        print('\n\n')
        print(f' [*] [PT finished] Min/Last Recon Loss: {performance_desc}, Total Cost: {(time.time() - start_time) / 60 / 60:.1f}h')
        print('\n\n')
        misc.save_checkpoint('ckpt-final.pth', args, args.ep-1, performance_desc, model_without_ddp.state_dict(with_config=True), optimizer.state_dict())

    args.remain_time, args.finish_time = '-', time.strftime("%m-%d %H:%M", time.localtime(time.time()))
    args.log_epoch()


def pre_train_one_ep(ep, args, itrt_train, iters_train, model: DistributedDataParallel, optimizer):
    model.train()
    me = misc.MetricLogger(delimiter=' ')
    me.add_meter('max_lr', misc.SmoothedValue(window_size=1, fmt='{value:.5f}'))
    header = f'[PT] Epoch: [{ep:3d}/{args.ep}]'

    optimizer.zero_grad()
    early_clipping = args.clip > 0 and not hasattr(optimizer, 'global_grad_norm')
    late_clipping = args.clip > 0 and hasattr(optimizer, 'global_grad_norm')
    if early_clipping:
        params_req_grad = [p for p in model.parameters() if p.requires_grad]

    # for every batch:
    for it, (inp, _) in enumerate(me.log_every(iters_train, itrt_train, 3, header)):
        # adjust lr and wd
        g_it = it + ep * iters_train
        min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(optimizer, args.lr, args.wd, args.wde, g_it, args.wp_ep * iters_train, args.ep * iters_train)

        # forward and backward
        inp = inp.to(args.device, non_blocking=True)
        SparK.forward  # no-op attribute access, kept as a jump-to-definition hint for the DDP call below
        active_ex, rec, loss = model(inp)
        optimizer.zero_grad()
        loss.backward()
        loss = loss.item()
        if not math.isfinite(loss):
            print(f'[rk{dist.get_rank():02d}] Loss is {loss}, stopping training!', force=True, flush=True)  # `print` here is assumed to be the dist-aware version patched in utils.misc, which accepts `force`
            sys.exit(-1)

        # optimize
        grad_norm = None
        if early_clipping:
            grad_norm = torch.nn.utils.clip_grad_norm_(params_req_grad, args.clip)
        optimizer.step()
        if late_clipping:
            grad_norm = optimizer.global_grad_norm
        torch.cuda.synchronize()

        # log
        me.update(last_loss=loss)
        me.update(max_lr=max_lr)
        if grad_norm is not None:
            me.update(orig_norm=grad_norm)

    me.synchronize_between_processes()
    return {k: meter.global_avg for k, meter in me.meters.items()}, me.iter_time.time_preds((args.ep - 1 - ep) * (iters_train + 10))


if __name__ == '__main__':
    main_pt()

models/__init__.py
@@ -0,0 +1,83 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from timm import create_model
from timm.loss import SoftTargetCrossEntropy
from timm.models.layers import drop
from models.convnext import ConvNeXt
from models.resnet import ResNet
_import_resnets_for_timm_registration = (ResNet,)
# log more
def _ex_repr(self):
    return ', '.join(
        f'{k}=' + (f'{v:g}' if isinstance(v, float) else str(v))
        for k, v in vars(self).items()
        if not k.startswith('_') and k != 'training'
        and not isinstance(v, (torch.nn.Module, torch.Tensor))
    )


for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy, drop.DropPath):
    if hasattr(clz, 'extra_repr'):
        clz.extra_repr = _ex_repr
    else:
        clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})'


model_alias_to_fullname = {
    'res50': 'resnet50',
    'res101': 'resnet101',
    'res152': 'resnet152',
    'res200': 'resnet200',
    'cnxS': 'convnext_small',
    'cnxB': 'convnext_base',
    'cnxL': 'convnext_large',
}
model_fullname_to_alias = {v: k for k, v in model_alias_to_fullname.items()}


pre_train_d = {  # default drop_path_rate, num of para, FLOPs, downsample_ratio, num of channel
    'resnet50': [dict(drop_path_rate=0.05), 25.6, 4.1, 32, 2048],
    'resnet101': [dict(drop_path_rate=0.08), 44.5, 7.9, 32, 2048],
    'resnet152': [dict(drop_path_rate=0.10), 60.2, 11.6, 32, 2048],
    'resnet200': [dict(drop_path_rate=0.15), 64.7, 15.1, 32, 2048],
    'convnext_small': [dict(sparse=True, drop_path_rate=0.2), 50.0, 8.7, 32, 768],
    'convnext_base': [dict(sparse=True, drop_path_rate=0.3), 89.0, 15.4, 32, 1024],
    'convnext_large': [dict(sparse=True, drop_path_rate=0.4), 198.0, 34.4, 32, 1536],
}
for v in pre_train_d.values():
    v[0]['pretrained'] = False
    v[0]['num_classes'] = 0
    v[0]['global_pool'] = ''


def build_sparse_encoder(name: str, input_size: int, sbn=False, drop_path_rate=0.0, verbose=False):
    from encoder import SparseEncoder
    name = model_alias_to_fullname.get(name, name)  # accept both an alias ('res50') and a full name ('resnet50'); otherwise aliases like the one used in scripts/pt.sh would miss pre_train_d
    kwargs, params, flops, downsample_raito, fea_dim = pre_train_d[name]
    if drop_path_rate != 0:
        kwargs['drop_path_rate'] = drop_path_rate
    print(f'[sparse_cnn] model kwargs={kwargs}')
    cnn = create_model(name, **kwargs)
    if hasattr(cnn, 'global_pool'):
        if callable(cnn.global_pool):
            cnn.global_pool = torch.nn.Identity()
        elif isinstance(cnn.global_pool, str):
            cnn.global_pool = ''
    if not isinstance(downsample_raito, int) or not isinstance(fea_dim, int):
        # infer the downsample ratio and feature dim with a dry run when the table gives no ints
        with torch.no_grad():
            cnn.eval()
            o = cnn(torch.rand(1, 3, input_size, input_size))
            downsample_raito = input_size // o.shape[-1]
            fea_dim = o.shape[1]
            cnn.train()
    print(f'[sparse_cnn] downsample_raito={downsample_raito}, fea_dim={fea_dim}')
    return SparseEncoder(cnn, input_size=input_size, downsample_raito=downsample_raito, encoder_fea_dim=fea_dim, verbose=verbose, sbn=sbn)
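
A usage sketch (hypothetical, not part of the commit; it assumes it is run from the repo root so that `encoder` is importable):

if __name__ == '__main__':
    enc = build_sparse_encoder('resnet50', input_size=224, sbn=False)
    print(enc.input_size, enc.downsample_raito, enc.fea_dim)  # 224 32 2048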

models/convnext.py
@@ -0,0 +1,212 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# This file is basically a copy to: https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_
from timm.models.registry import register_model
from encoder import SparseConvNeXtBlock, SparseConvNeXtLayerNorm
class ConvNeXt(nn.Module):
    r""" ConvNeXt
        A PyTorch impl of : `A ConvNet for the 2020s` -
        https://arxiv.org/pdf/2201.03545.pdf

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """

    def __init__(self, in_chans=3, num_classes=1000,
                 depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
                 layer_scale_init_value=1e-6, head_init_scale=1., global_pool='avg',
                 sparse=True,
                 ):
        super().__init__()

        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            SparseConvNeXtLayerNorm(dims[0], eps=1e-6, data_format="channels_first", sparse=sparse)
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                SparseConvNeXtLayerNorm(dims[i], eps=1e-6, data_format="channels_first", sparse=sparse),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
        self.drop_path_rate = drop_path_rate
        self.layer_scale_init_value = layer_scale_init_value
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[SparseConvNeXtBlock(dim=dims[i], drop_path=dp_rates[cur + j],
                                      layer_scale_init_value=layer_scale_init_value, sparse=sparse) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.depths = depths
        self.apply(self._init_weights)
        if num_classes > 0:
            self.norm = SparseConvNeXtLayerNorm(dims[-1], eps=1e-6, sparse=False)  # final norm layer for LE/FT; should not be sparse
            self.fc = nn.Linear(dims[-1], num_classes)
            # self.fc.weight.data.mul_(head_init_scale)  # todo: perform this outside
            # self.fc.bias.data.mul_(head_init_scale)    # todo: perform this outside
        else:
            self.norm = nn.Identity()
            self.fc = nn.Identity()
        self.with_pooling = len(global_pool) > 0

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x, pyramid: int):  # pyramid: 0, 1, 2, 3, 4
        ls = []
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
            if pyramid:
                ls.append(x)

        if pyramid:
            # keep only the deepest `pyramid` feature maps, padded with None at the front
            for i in range(len(ls) - pyramid - 1, -1, -1):
                del ls[i]
            return [None] * (4 - pyramid) + ls
        else:
            if self.with_pooling:
                x = x.mean([-2, -1])  # global average pooling, (N, C, H, W) -> (N, C)
            return x

    def forward(self, x, pyramid=0):
        if pyramid == 0:
            x = self.forward_features(x, pyramid=pyramid)
            x = self.fc(self.norm(x))
            return x
        else:
            return self.forward_features(x, pyramid=pyramid)

    def get_classifier(self):
        return self.fc

    def extra_repr(self):
        return f'drop_path_rate={self.drop_path_rate}, layer_scale_init_value={self.layer_scale_init_value:g}'

    def get_layer_id_and_scale_exp(self, para_name: str):
        N = 12 if self.depths[-2] > 9 else 6
        if para_name.startswith("downsample_layers"):
            stage_id = int(para_name.split('.')[1])
            if stage_id == 0:
                layer_id = 0
            elif stage_id == 1 or stage_id == 2:
                layer_id = stage_id + 1
            else:  # stage_id == 3
                layer_id = N
        elif para_name.startswith("stages"):
            stage_id = int(para_name.split('.')[1])
            block_id = int(para_name.split('.')[2])
            if stage_id == 0 or stage_id == 1:
                layer_id = stage_id + 1
            elif stage_id == 2:
                layer_id = 3 + block_id // 3
            else:  # stage_id == 3
                layer_id = N
        else:
            layer_id = N + 1  # after backbone
        return layer_id, N + 1 - layer_id
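
# Worked example (illustrative): for convnext_small, depths=[3, 3, 27, 3], so
# depths[-2] > 9 and N = 12. The parameter 'stages.2.13.dwconv.weight' gives
# stage_id=2 and block_id=13, hence layer_id = 3 + 13 // 3 = 7 and a scale
# exponent of N + 1 - layer_id = 6; a head parameter such as 'fc.weight' hits
# the final else and gets layer_id = N + 1 = 13 with scale exponent 0.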
model_urls = {
    "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
    "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
    "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
    "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
    "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
    "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
    "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
    "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
    "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
}


@register_model
def convnext_tiny(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
    if pretrained:
        url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def convnext_small(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
    if pretrained:
        url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def convnext_base(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
    if pretrained:
        url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def convnext_large(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
    if pretrained:
        url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
    if pretrained:
        assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
        url = model_urls['convnext_xlarge_22k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model


if __name__ == '__main__':
    from timm.models import create_model
    c = create_model('convnext_small', sparse=False)
    with torch.no_grad():
        x = torch.rand(2, 3, 224, 224)
        print(c(x).shape)
        print([None if f is None else f.shape for f in c(x, pyramid=1)])
        print([None if f is None else f.shape for f in c(x, pyramid=2)])
        print([None if f is None else f.shape for f in c(x, pyramid=3)])
        print([None if f is None else f.shape for f in c(x, pyramid=4)])

models/resnet.py
@@ -0,0 +1,105 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch.nn.functional as F
from timm.models.resnet import ResNet
def forward_features(self, x, pyramid: int):  # pyramid: 0, 1, 2, 3, 4
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.act1(x)
    x = self.maxpool(x)

    ls = []
    x = self.layer1(x)
    if pyramid: ls.append(x)
    x = self.layer2(x)
    if pyramid: ls.append(x)
    x = self.layer3(x)
    if pyramid: ls.append(x)
    x = self.layer4(x)
    if pyramid: ls.append(x)

    if pyramid:
        # keep only the deepest `pyramid` feature maps, padded with None at the front
        for i in range(len(ls) - pyramid - 1, -1, -1):
            del ls[i]
        return [None] * (4 - pyramid) + ls
    else:
        return x


def forward(self, x, pyramid=0):
    if pyramid == 0:
        x = self.forward_features(x, pyramid=pyramid)
        x = self.global_pool(x)
        if self.drop_rate:
            x = F.dropout(x, p=float(self.drop_rate), training=self.training)
        x = self.fc(x)
        return x
    else:
        return self.forward_features(x, pyramid=pyramid)


def resnets_get_layer_id_and_scale_exp(self, para_name: str):
    # stages:
    #   50     : [3, 4, 6, 3]
    #   101    : [3, 4, 23, 3]
    #   152    : [3, 8, 36, 3]
    #   200    : [3, 24, 36, 3]
    #   eca269d: [3, 30, 48, 8]

    L2, L3 = len(self.layer2), len(self.layer3)
    if L2 == 4 and L3 == 6:
        blk2, blk3 = 2, 3
    elif L2 == 4 and L3 == 23:
        blk2, blk3 = 2, 3
    elif L2 == 8 and L3 == 36:
        blk2, blk3 = 4, 4
    elif L2 == 24 and L3 == 36:
        blk2, blk3 = 4, 4
    elif L2 == 30 and L3 == 48:
        blk2, blk3 = 5, 6
    else:
        raise NotImplementedError

    N2, N3 = math.ceil(L2 / blk2 - 1e-5), math.ceil(L3 / blk3 - 1e-5)
    N = 2 + N2 + N3
    if para_name.startswith('layer'):  # layer1 to layer4
        stage_id, block_id = int(para_name.split('.')[0][5:]), int(para_name.split('.')[1])
        if stage_id == 1:
            layer_id = 1
        elif stage_id == 2:
            layer_id = 2 + block_id // blk2  # 2, 3
        elif stage_id == 3:
            layer_id = 2 + N2 + block_id // blk3  # r50: 4, 5    r101: 4, 5, ..., 11
        else:  # stage_id == 4
            layer_id = N  # r50: 6    r101: 12
    elif para_name.startswith('fc.'):
        layer_id = N + 1  # r50: 7    r101: 13
    else:
        layer_id = 0
    return layer_id, N + 1 - layer_id  # r50: 0-7, 7-0    r101: 0-13, 13-0


ResNet.get_layer_id_and_scale_exp = resnets_get_layer_id_and_scale_exp
ResNet.forward_features = forward_features
ResNet.forward = forward


if __name__ == '__main__':
    import torch
    from timm.models import create_model
    r = create_model('resnet50')
    with torch.no_grad():
        print(r(torch.rand(2, 3, 224, 224)).shape)
        print(r(torch.rand(2, 3, 224, 224), pyramid=1))
        print(r(torch.rand(2, 3, 224, 224), pyramid=2))
        print(r(torch.rand(2, 3, 224, 224), pyramid=3))
        print(r(torch.rand(2, 3, 224, 224), pyramid=4))

sampler.py
@@ -0,0 +1,167 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import random
import numpy as np
import torch
from torch.utils.data.sampler import Sampler
import dist
def worker_init_fn(worker_id):
# https://pytorch.org/docs/stable/notes/randomness.html#dataloader
worker_seed = torch.initial_seed() % 2 ** 32
np.random.seed(worker_seed)
random.seed(worker_seed)
class RASampler(Sampler):
"""Sampler that restricts data loading to a subset of the dataset for distributed,
with repeated augmentation.
It ensures that different each augmented version of a sample will be visible to a
different process (GPU).
Heavily based on 'torch.utils.data.DistributedSampler'.
This is borrowed from the DeiT Repo:
https://github.com/facebookresearch/deit/blob/main/samplers.py
"""
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3):
if num_replicas is None:
num_replicas = dist.get_world_size()
if rank is None:
rank = dist.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
self.shuffle = shuffle
self.seed = seed
self.repetitions = repetitions
def __iter__(self):
if self.shuffle:
# Deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset)))
# Add extra samples to make it evenly divisible
indices = [ele for ele in indices for i in range(self.repetitions)]
indices += indices[: (self.total_size - len(indices))]
assert len(indices) == self.total_size
# Subsample
indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices[: self.num_selected_samples])
def __len__(self):
return self.num_selected_samples
def set_epoch(self, epoch):
self.epoch = epoch
class InfiniteBatchSampler(Sampler):
def __init__(self, dataset_len, batch_size, seed=0, filling=False, shuffle=True, drop_last=False):
self.dataset_len = dataset_len
self.batch_size = batch_size
self.iters_per_ep = dataset_len // batch_size if drop_last else (dataset_len + batch_size - 1) // batch_size
self.max_p = self.iters_per_ep * batch_size
self.filling = filling
self.shuffle = shuffle
self.epoch = 0
self.seed = seed
self.indices = self.gener_indices()
def gener_indices(self):
if self.shuffle:
g = torch.Generator()
g.manual_seed(self.epoch + self.seed)
indices = torch.randperm(self.dataset_len, generator=g).numpy()
else:
indices = torch.arange(self.dataset_len).numpy()
tails = self.batch_size - (self.dataset_len % self.batch_size)
if tails != self.batch_size and self.filling:
tails = indices[:tails]
np.random.shuffle(indices)
indices = np.concatenate((indices, tails))
# built-in list/tuple is faster than np.ndarray (when collating the data via a for-loop)
# noinspection PyTypeChecker
return tuple(indices.tolist())
def __iter__(self):
self.epoch = 0
while True:
self.epoch += 1
p, q = 0, 0
while p < self.max_p:
q = p + self.batch_size
yield self.indices[p:q]
p = q
if self.shuffle:
self.indices = self.gener_indices()
def __len__(self):
return self.iters_per_ep
class DistInfiniteBatchSampler(InfiniteBatchSampler):
def __init__(self, world_size, rank, dataset_len, glb_batch_size, seed=0, repeated_aug=0, filling=False, shuffle=True):
# from torchvision.models import ResNet50_Weights
# RA sampler: https://github.com/pytorch/vision/blob/5521e9d01ee7033b9ee9d421c1ef6fb211ed3782/references/classification/sampler.py
assert glb_batch_size % world_size == 0
self.world_size, self.rank = world_size, rank
self.dataset_len = dataset_len
self.glb_batch_size = glb_batch_size
self.batch_size = glb_batch_size // world_size
self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size
self.filling = filling
self.shuffle = shuffle
self.repeated_aug = repeated_aug
self.epoch = 0
self.seed = seed
self.indices = self.gener_indices()
def gener_indices(self):
        global_max_p = self.iters_per_ep * self.glb_batch_size  # global_max_p % world_size == 0 is guaranteed because glb_batch_size % world_size == 0
if self.shuffle:
g = torch.Generator()
g.manual_seed(self.epoch + self.seed)
global_indices = torch.randperm(self.dataset_len, generator=g)
if self.repeated_aug > 1:
global_indices = global_indices[:(self.dataset_len + self.repeated_aug - 1) // self.repeated_aug].repeat_interleave(self.repeated_aug, dim=0)[:global_max_p]
else:
global_indices = torch.arange(self.dataset_len)
filling = global_max_p - global_indices.shape[0]
if filling > 0 and self.filling:
global_indices = torch.cat((global_indices, global_indices[:filling]))
global_indices = tuple(global_indices.numpy().tolist())
seps = torch.linspace(0, len(global_indices), self.world_size + 1, dtype=torch.int)
local_indices = global_indices[seps[self.rank]:seps[self.rank + 1]]
self.max_p = len(local_indices)
return local_indices
if __name__ == '__main__':
W = 16
for rk in range(W):
ind = DistInfiniteBatchSampler(W, rk, 5024, 5024).gener_indices()
print(rk, len(ind))
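    # --- illustrative addition (not part of the original file) ---
    # In training, DistInfiniteBatchSampler is passed to DataLoader as
    # `batch_sampler`: it yields batches of *local* indices forever, so the loop
    # draws `iters_per_ep` batches per epoch instead of exhausting the loader.
    s = DistInfiniteBatchSampler(world_size=2, rank=0, dataset_len=1000, glb_batch_size=128, shuffle=True, filling=True)
    it = iter(s)
    for _ in range(3):
        batch = next(it)
        print(f'a batch of {len(batch)} local indices, e.g. {batch[:4]}')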

@ -0,0 +1,49 @@
#!/usr/bin/env bash
# an example of pre-training (note that `/path/to/imagenet` should contain directories named `train` and `val`):
# > cd /path/to/SparK
# > bash ./scripts/pt.sh experiment_name /path/to/imagenet --num_nodes=1 --ngpu_per_node=8 --node_rank=0 --master_address=128.0.0.0 --master_port=30000 --model=res50 --ep=400
####### template begins #######
SCRIPTS_DIR=$(cd $(dirname $0); pwd)
cd "${SCRIPTS_DIR}"
cd ../
SPARK_DIR=$(pwd)
echo "SPARK_DIR=${SPARK_DIR}"
shopt -s expand_aliases
alias python=python3
alias to_scripts_dir='cd "${SCRIPTS_DIR}"'
alias to_spark_dir='cd "${SPARK_DIR}"'
alias print='echo "$(date +"[%m-%d %H:%M:%S]") (exp.sh)=>"'
function mkd() {
mkdir -p "$1" >/dev/null 2>&1
}
####### template ends #######
EXP_NAME=$1
DATA_PATH=$2
EXP_DIR="${SPARK_DIR}/output_${EXP_NAME}"
mkd "${EXP_DIR}"
print "===================== Args ====================="
print "EXP_NAME: ${EXP_NAME}"
print "DATA_PATH: ${DATA_PATH}"
print "EXP_DIR: ${EXP_DIR}"
print "[other_args sent to launch.py]: ${*:3}"
print "================================================"
print ""
print "============== Pretraining starts =============="
to_spark_dir
python launch.py \
--main_py_relpath main.py \
--exp_name "${EXP_NAME}" \
--data_path "${DATA_PATH}" \
--exp_dir "${EXP_DIR}" \
"${*:3}"
print "============== Pretraining ends =============="

@ -0,0 +1,313 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import random
from pprint import pformat
from typing import List
import numpy as np
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
from timm.models.layers import trunc_normal_
import encoder
from decoder import LightDecoder
class SparK(nn.Module):
def __init__(
self, sparse_encoder: encoder.SparseEncoder, dense_decoder: LightDecoder,
mask_ratio=0.6, mask_ratio2=0.6, uniform=False,
using_pe=True, pix_norm=0, dense_loss=False, loss_l2=True,
en_de_norm='bn', en_de_lin=True, sbn=False, pyramid=1, # 1 for single-scale pre-training; 4 for full-scale pre-training
):
super().__init__()
input_size, downsample_raito = sparse_encoder.input_size, sparse_encoder.downsample_raito
self.downsample_raito = downsample_raito
fmap_size = input_size // downsample_raito
self.fmap_size = fmap_size
if mask_ratio != mask_ratio2 and not uniform: # with an extra active site
k = 1 / fmap_size**2
mask_ratio = min(1, mask_ratio / (1-k))
mask_ratio2 = min(1, mask_ratio2 / (1-k))
self.mask_ratio = (mask_ratio, mask_ratio2)
self.ratios = torch.tensor([self.mask_ratio[0], self.mask_ratio[1], (self.mask_ratio[0] + self.mask_ratio[1]) / 2])
self.uniform = uniform
self.len_keep = round(fmap_size * fmap_size * (1-mask_ratio))
self.pix_norm = int(pix_norm)
self.sparse_encoder = sparse_encoder
self.dense_decoder = dense_decoder
self.sbn = sbn
self.pyramid = pyramid
en_de_norm = en_de_norm.lower()
self.en_de_norm_str = en_de_norm
self.en_de_lin_bool = en_de_lin
self.en_de_norms = nn.ModuleList()
self.en_de_lins = nn.ModuleList()
self.using_pe = using_pe
self.pos_embeds = nn.ParameterList()
self.mask_tokens = nn.ParameterList()
fea, d_fea, fmap = self.sparse_encoder.fea_dim, self.dense_decoder.fea_dim, fmap_size
for i in range(self.pyramid):
if en_de_norm == 'bn':
n = (encoder.SparseSyncBatchNorm2d if sbn else encoder.SparseBatchNorm2d)(fea)
elif en_de_norm == 'ln':
n = encoder.SparseConvNeXtLayerNorm(fea, data_format='channels_first', sparse=True)
else:
n = nn.Identity()
self.en_de_norms.append(n)
            kernel_size = 1 if i <= 0 else 3
            if i == 0 and fea == d_fea:  # no projection needed at the top level when the dims already match
                l = nn.Identity()
            else:
                l = nn.Conv2d(fea, d_fea, kernel_size=kernel_size, stride=1, padding=kernel_size//2, bias=True)
            print(f'[mid, py={self.pyramid}][edl {i}]: k={kernel_size}, #para = {sum(x.numel() for x in l.parameters())/1e6:.2f}')
            self.en_de_lins.append(l)
if self.using_pe:
p = torch.from_numpy(get_2d_sincos_pos_embed(fea, fmap)).float()
p = p.reshape(1, fmap, fmap, fea).permute(0, 3, 1, 2).contiguous()
p = nn.Parameter(p, requires_grad=False)
self.pos_embeds.append(p)
p = nn.Parameter(torch.zeros(1, fea, 1, 1))
trunc_normal_(p, mean=0, std=.02, a=-.02, b=.02)
self.mask_tokens.append(p)
fea //= 2
d_fea //= 2
fmap *= 2
print(f'[mid, py={self.pyramid}][mask_tokens]: {tuple(p.numel() for p in self.mask_tokens)}')
self.loss_l2, self.dense_loss = loss_l2, dense_loss
m = torch.tensor(IMAGENET_DEFAULT_MEAN).view(1, 3, 1, 1)
s = torch.tensor(IMAGENET_DEFAULT_STD).view(1, 3, 1, 1)
self.register_buffer('imn_m', m)
self.register_buffer('imn_s', s)
# self.register_buffer('norm_black', (torch.ones(1, 3, input_size, input_size) * 0.45 - m) / s)
self.register_buffer('norm_black', torch.zeros(1, 3, input_size, input_size))
self.vis_active = self.vis_active_ex = self.vis_inp = self.vis_inp_mask = ...
def mask(self, shape, device, generator=None):
B, C, H, W = shape
f = self.fmap_size
if self.mask_ratio[0] == self.mask_ratio[1]:
len_keep = self.len_keep
elif self.uniform:
r = random.uniform(self.mask_ratio[0], self.mask_ratio[1])
len_keep = round(f * f * (1-r))
else:
i1, i2, i3, i4 = np.linspace(0, B, 4, dtype=int).tolist()
l1, l2, l3 = i2-i1, i3-i2, i4-i3
r1, r2, r3 = self.ratios[torch.randperm(3, generator=generator)].tolist()
r = torch.tensor([r1]*l1 + [r2]*l2 + [r3]*l3, device=device).view(-1, 1, 1)
active = torch.rand(B, f, f, device=device, generator=generator) >= r
rr, cc = torch.randint(low=0, high=f, size=(2, B), generator=generator).unbind(0)
active[torch.arange(B), rr, cc] = True # an extra active site
return active.unsqueeze_(1)
idx = torch.rand(B, f*f, generator=generator).argsort(dim=1)
idx = idx[:, :len_keep].to(device) # (B, len_keep)
return torch.zeros(B, f*f, dtype=torch.bool, device=device).scatter_(dim=1, index=idx, value=True).view(B, 1, f, f)
def forward(self, raw_inp: torch.Tensor, active=None):
inp_bchw = raw_inp
# spatial mask
if active is None:
active: torch.BoolTensor = self.mask(inp_bchw.shape, inp_bchw.device) # (B, 1, f, f)
encoder._cur_active = active
active_ex = active.repeat_interleave(self.downsample_raito, 2).repeat_interleave(self.downsample_raito, 3) # (B, 1, H, W)
masked_bchw = inp_bchw * active_ex
# get hierarchical encoded sparse features (a list containing four feature maps)
fea_bcffs: List[torch.Tensor] = self.sparse_encoder(masked_bchw, pyramid=self.pyramid)
fea_bcffs.reverse() # from the smallest feature map to the largest
cur_active = active
to_dec = []
for i, bcff in enumerate(fea_bcffs): # from the smallest feature map to the largest
if bcff is not None:
# fill in empty positions with [mask] embeddings
bcff = self.en_de_norms[i](bcff)
mask_tokens = self.mask_tokens[i].expand_as(bcff)
if self.using_pe:
mask_tokens = mask_tokens + self.pos_embeds[i].expand_as(bcff)
bcff = torch.where(cur_active.expand_as(bcff), bcff, mask_tokens)
bcff: torch.Tensor = self.en_de_lins[i](bcff)
to_dec.append(bcff)
cur_active = cur_active.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3) # dilate the mask map
# decode and reconstruct
rec_bchw = self.dense_decoder(to_dec)
# calc loss
mean, var, spatial_loss = self.spatial_loss(raw_inp, rec_bchw, active)
return active_ex, rec_bchw, spatial_loss
def spatial_loss(self, inp, rec, active): # active: (B, 1, f, f)
mean = var = None
if self.pix_norm == 2:
mean, var = inp.mean(dim=(2, 3), keepdim=True), None
rec = rec + mean
inp = self.patchify(inp)
rec = self.patchify(rec)
# (B, L=fmap_size**2, N=downsample_raito**2 * C)
if self.pix_norm == 1:
mean = inp.mean(dim=-1, keepdim=True)
var = (inp.var(dim=-1, keepdim=True) + 1e-6)**.5
inp = (inp - mean) / var
loss_spa = (rec-inp)**2 if self.loss_l2 else (rec-inp).abs()
if self.dense_loss:
return mean, var, loss_spa.mean() # mean loss on all patches
else:
            loss_spa = loss_spa.mean(dim=2, keepdim=False)  # (B, L, N) => (B, L)
non_active = active.logical_not().int().view(active.shape[0], -1) # (B, 1, f, f) => (B, L)
return mean, var, loss_spa.mul_(non_active).sum() / (non_active.sum() + 1e-8) # mean loss on removed patches
def patchify(self, bchw):
p = self.downsample_raito
h = w = self.fmap_size
B, C = bchw.shape[:2]
bchw = bchw.reshape(shape=(B, C, h, p, w, p))
bchw = torch.einsum('bchpwq->bhwpqc', bchw)
bln = bchw.reshape(shape=(B, h*w, p**2 * C)) # (B, f*f, downsample_raito*downsample_raito*3)
return bln
def unpatchify(self, bln):
p = self.downsample_raito
h = w = self.fmap_size
B, C = bln.shape[0], bln.shape[-1] // p**2
bln = bln.reshape(shape=(B, h, w, p, p, C))
bln = torch.einsum('bhwpqc->bchpwq', bln)
bchw = bln.reshape(shape=(B, C, h * p, w * p))
return bchw
def denorm_for_vis(self, x, clamp):
x = x * self.imn_s
x += self.imn_m
if clamp:
x = torch.clamp(x, 0, 1)
return x
def __repr__(self):
return (
f'\n'
f'[SparK.config]: {pformat(self.get_config(), indent=2, width=250)}\n'
f'[SparK.structure]: {super(SparK, self).__repr__().replace(SparK.__name__, "")}\n'
f'[SparK.dec]: {self.dense_decoder.num_para()}'
)
def get_config(self):
return {
# self
'mask_ratio': self.mask_ratio[0], 'mask_ratio2': self.mask_ratio[1], 'uniform': self.uniform,
'using_pe': self.using_pe, 'pix_norm': self.pix_norm,
'dense_loss': self.dense_loss, 'loss_l2': self.loss_l2,
'en_de_norm': self.en_de_norm_str, 'en_de_lin': self.en_de_lin_bool, 'sbn': self.sbn, 'pyramid': self.pyramid,
# enc
'input_size': self.sparse_encoder.input_size,
# dec
'dec_fea_dim': self.dense_decoder.fea_dim, 'double': self.dense_decoder.double_bool, 'heavy': self.dense_decoder.heavy,
}
def state_dict(self, destination=None, prefix='', keep_vars=False, with_config=False):
state = super(SparK, self).state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
if with_config:
            state['config'] = self.get_config()  # todo: this seems to cause a DDP broadcast error??
return state
def load_state_dict(self, state_dict, strict=True):
config = state_dict.pop('config', None)
incompatible_keys = super(SparK, self).load_state_dict(state_dict, strict=strict)
if config is not None:
for k, v in self.get_config().items():
if config.get(k, None) != v:
                    err = f'[SparK.load_state_dict] config mismatch: this.{k}={v} (ckpt.{k}={config.get(k, None)})'
if strict:
raise AttributeError(err)
else:
print(err)
return incompatible_keys
def _make_divisible(v, divisor=8, min_value=None):
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
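# Illustrative examples (not part of the original file) of the rounding rule above:
#   _make_divisible(30)             -> 32   (round to the nearest multiple of 8)
#   _make_divisible(20)             -> 24   (ties round up)
#   _make_divisible(40, divisor=32) -> 64   (32 would shrink 40 by more than 10%, so bump up)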
def get_2d_sincos_pos_embed(embed_dim, grid_size):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed in NumPy 1.24; use the explicit dtype
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
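# --- illustrative sketch (not part of the original file) ---
def _demo_pos_embed():
    # The 2-D sin-cos table for a 7x7 feature map with 128-dim embeddings has
    # shape (49, 128); SparK.__init__ reshapes it to (1, C, f, f) and adds it to
    # the [mask] tokens when `using_pe` is True.
    pe = get_2d_sincos_pos_embed(128, 7)
    print(pe.shape)  # (49, 128)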
if __name__ == '__main__':
    # NOTE: `SparK.test_mask()` and `SparK.test_align()` from the original source
    # are not defined anywhere in this file and would raise AttributeError, so
    # they are not called here.
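    # --- illustrative smoke test (not part of the original file) ---
    # Check that patchify/unpatchify are exact inverses, using a dummy stand-in
    # object (downsample_raito=32 and fmap_size=7 match a 224x224 ResNet-50 input).
    from types import SimpleNamespace
    dummy = SimpleNamespace(downsample_raito=32, fmap_size=7)
    x = torch.rand(2, 3, 224, 224)
    assert torch.allclose(SparK.unpatchify(dummy, SparK.patchify(dummy, x)), x)
    print('patchify/unpatchify round-trip OK')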

@ -0,0 +1,172 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
from typing import Any, Callable, Optional, Tuple
import PIL.Image as PImage
import torch
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, create_transform
from timm.data.transforms_factory import transforms_imagenet_eval
from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS
from torchvision.transforms import transforms
import dist
try:
from torchvision.transforms import InterpolationMode
interpolation = InterpolationMode.BICUBIC
except ImportError:
    import PIL
    interpolation = PIL.Image.BICUBIC
def pil_loader(path):
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img: PImage.Image = PImage.open(f).convert('RGB')
return img
class ImageNetDataset(DatasetFolder):
def __init__(
self,
root: str,
train: bool,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
max_cls_id: int = 1000,
only=-1,
):
for postfix in (os.path.sep, 'train', 'val'):
if root.endswith(postfix):
root = root[:-len(postfix)]
root = os.path.join(root, 'train' if train else 'val')
super(ImageNetDataset, self).__init__(
root,
# loader=ImageLoader(train),
loader=pil_loader,
extensions=IMG_EXTENSIONS if is_valid_file is None else None,
transform=transform, target_transform=target_transform, is_valid_file=is_valid_file
)
if only > 0:
g = torch.Generator()
g.manual_seed(0)
idx = torch.randperm(len(self.samples), generator=g).numpy().tolist()
ws = dist.get_world_size()
res = (max_cls_id * only) % ws
more = 0 if res == 0 else (ws - res)
max_total = max_cls_id * only + more
if (max_total // ws) % 2 == 1:
more += ws
max_total += ws
d = {c: [] for c in range(max_cls_id)}
max_len = {c: only for c in range(max_cls_id)}
for c in range(max_cls_id-more, max_cls_id):
max_len[c] += 1
total = 0
for i in idx:
path, target = self.samples[i]
if len(d[target]) < max_len[target]:
d[target].append((path, target))
total += 1
if total == max_total:
break
sp = []
            for l in d.values():
                sp.extend(l)
print(f'[ds] more={more}, len(sp)={len(sp)}')
self.samples = tuple(sp)
self.targets = tuple([s[1] for s in self.samples])
else:
self.samples = tuple(filter(lambda item: item[-1] < max_cls_id, self.samples))
self.targets = tuple([s[1] for s in self.samples])
def __getitem__(self, index: int) -> Tuple[Any, Any]:
path, target = self.samples[index]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
def build_imagenet(mode, data_path, data_set, img_size, eval_crop_pct=None, rrc=0.3, aa='rand-m7-mstd0.5', re_prob=0.0, colorj=0.4):
mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
norm = transforms.Normalize(mean=mean, std=std)
if img_size >= 384:
trans_val = transforms.Compose([
transforms.Resize((img_size, img_size), interpolation=interpolation),
transforms.ToTensor(),
norm,
])
else:
trans_val = transforms_imagenet_eval(
img_size=img_size, interpolation='bicubic', crop_pct=eval_crop_pct,
mean=mean, std=std
)
mode = mode.lower()
if mode == 'pt':
trans_train = transforms.Compose([
transforms.RandomResizedCrop(img_size, scale=(rrc, 1.0), interpolation=interpolation),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
norm,
])
elif mode == 'le':
trans_train = transforms.Compose([
transforms.RandomResizedCrop(img_size, interpolation=interpolation),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
norm,
])
else:
trans_train = create_transform(
is_training=True,
input_size=img_size,
auto_augment=aa,
interpolation='bicubic',
re_prob=re_prob,
re_mode='pixel',
re_count=1,
color_jitter=colorj,
mean=mean, std=std,
)
if data_path.endswith(os.path.sep):
data_path = data_path[:-len(os.path.sep)]
for postfix in ('train', 'val'):
if data_path.endswith(postfix):
data_path = data_path[:-len(postfix)]
if data_set == 'imn':
dataset_train = ImageNetDataset(root=data_path, transform=trans_train, train=True)
dataset_val = ImageNetDataset(root=data_path, transform=trans_val, train=False)
num_classes = 1000
else:
raise NotImplementedError
print_transform(trans_train, '[train]')
print_transform(trans_val, '[val]')
return dataset_train, dataset_val
def print_transform(transform, s):
print(f'Transform {s} = ')
for t in transform.transforms:
print(t)
print('---------------------------\n')
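# --- illustrative sketch (not part of the original file) ---
# Typical use, assuming `/path/to/imagenet` holds `train/` and `val/` subfolders:
#
#   dataset_train, dataset_val = build_imagenet(
#       mode='pt', data_path='/path/to/imagenet', data_set='imn',
#       img_size=224, rrc=0.67)
#
# mode='pt' keeps only RandomResizedCrop + horizontal flip (masked-modeling
# pre-training works best with weak augmentation); 'le' is for linear probing,
# and any other mode falls back to timm's full fine-tuning recipe (RandAugment,
# color jitter, random erasing).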

@ -0,0 +1,73 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from pprint import pformat
def lr_wd_annealing(optimizer, peak_lr, wd, wd_end, cur_it, wp_it, max_it):
wp_it = round(wp_it)
if cur_it < wp_it:
cur_lr = 0.005 * peak_lr + 0.995 * peak_lr * cur_it / wp_it
else:
ratio = (cur_it - wp_it) / (max_it-1 - wp_it)
cur_lr = 0.001 * peak_lr + 0.999 * peak_lr * (0.5 + 0.5 * math.cos(math.pi * ratio))
ratio = cur_it / (max_it-1)
cur_wd = wd_end + (wd - wd_end) * (0.5 + 0.5 * math.cos(math.pi * ratio))
inf = 1e6
min_lr, max_lr = inf, -1
min_wd, max_wd = inf, -1
for param_group in optimizer.param_groups:
param_group['lr'] = cur_lr * param_group.get('lr_scale', 1) # 'lr_scale' could be assigned
max_lr = max(max_lr, param_group['lr'])
min_lr = min(min_lr, param_group['lr'])
param_group['weight_decay'] = cur_wd * param_group.get('weight_decay_scale', 1)
max_wd = max(max_wd, param_group['weight_decay'])
if param_group['weight_decay'] > 0:
min_wd = min(min_wd, param_group['weight_decay'])
if min_lr == inf: min_lr = -1
if min_wd == inf: min_wd = -1
return min_lr, max_lr, min_wd, max_wd
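# --- illustrative sketch (not part of the original file) ---
def _demo_schedule():
    # Print the warmup + cosine schedule on a toy optimizer: lr ramps up linearly
    # for the first wp_it iterations, then decays with a half-cosine; weight decay
    # is cosine-annealed from wd to wd_end over the whole run.
    import torch
    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.)
    for it in (0, 5, 10, 50, 99):
        _, lr, _, wd = lr_wd_annealing(opt, peak_lr=2e-4, wd=0.04, wd_end=0.2, cur_it=it, wp_it=10, max_it=100)
        print(f'it={it:3d}: lr={lr:.2e}, wd={wd:.3f}')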
def get_param_groups(model, nowd_keys=(), lr_scale=0.0):
with_lr_scale = hasattr(model, 'get_layer_id_and_scale_exp') and 0 < lr_scale < 1
print(f'[get_ft_param_groups][lr decay] with_lr_scale={with_lr_scale}, ft_lr_scale={lr_scale}')
para_groups, para_groups_dbg = {}, {}
for name, para in model.named_parameters():
if not para.requires_grad:
continue # frozen weights
if len(para.shape) == 1 or name.endswith('.bias') or any(k in name for k in nowd_keys):
wd_scale, group_name = 0., 'no_decay'
else:
wd_scale, group_name = 1., 'decay'
if with_lr_scale:
layer_id, scale_exp = model.get_layer_id_and_scale_exp(name)
group_name = f'layer{layer_id}_' + group_name
cur_lr_scale = lr_scale ** scale_exp
dbg = f'[layer {layer_id}][sc = {lr_scale} ** {scale_exp}]'
else:
cur_lr_scale = 1
            dbg = '[no scale]'
if group_name not in para_groups:
para_groups[group_name] = {'params': [], 'weight_decay_scale': wd_scale, 'lr_scale': cur_lr_scale}
para_groups_dbg[group_name] = {'params': [], 'weight_decay_scale': wd_scale, 'lr_scale': dbg}
para_groups[group_name]['params'].append(para)
para_groups_dbg[group_name]['params'].append(name)
for g in para_groups_dbg.values():
g['params'] = pformat(', '.join(g['params']), width=200)
print(f'[get_ft_param_groups] param groups = \n{pformat(para_groups_dbg, indent=2, width=250)}\n')
return list(para_groups.values())
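# --- illustrative sketch (not part of the original file) ---
def _demo_param_groups():
    # Biases and 1-D parameters (e.g. norm scales) land in the 'no_decay' group
    # (weight_decay_scale=0.), everything else in 'decay'; the returned list can
    # be handed straight to a torch optimizer, and lr_wd_annealing above reads
    # the extra 'lr_scale' / 'weight_decay_scale' keys from each group.
    import torch.nn as nn
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
    groups = get_param_groups(model, nowd_keys=('bias',))
    print([(g['weight_decay_scale'], len(g['params'])) for g in groups])  # [(1.0, 1), (0.0, 3)]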

@ -0,0 +1,163 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
import re
import sys
from tap import Tap
import dist
line_sep = f'\n{"=" * 80}\n'
class Args(Tap):
# environment
local_rank: int # useless
exp_name: str
data_path: str
exp_dir: str
log_txt_name: str = '/some/path/like/this/log.txt'
resume: str = ''
seed: int = 1
device: str = 'cpu'
# key MIM hp
mask: float = 0.6
mask2: float = -1
uni: bool = False
pe: bool = False
pn: int = 1
py: int = 4
# other MIM hp
den: bool = False
loss_l2: bool = True
en_de_norm: str = 'bn'
en_de_lin: bool = True
# encoder
model: str = 'res50'
model_alias: str = 'res50'
input_size: int = 224
sbn: bool = True
# decoder
dec_dim: int = 512 # [could be changed in `main.py`]
double: bool = True
hea: str = '0_1'
cmid: int = 0
# pre-training hyperparameters
glb_batch_size: int = 0
batch_size: int = 0 # batch size per GPU
dp: float = 0.0
base_lr: float = 2e-4
lr: float = None
wd: float = 0.04
wde: float = 0.2
ep: int = 1600
wp_ep: int = 40
clip: int = 5.
opt: str = ''
ada: float = 0.
# data hyperparameters
data_set: str = 'imn'
rrc: float = 0.67
bs: int = 4096
num_workers: int = 8
# would be added during runtime
cmd: str = ''
commit_id: str = ''
commit_msg: str = ''
last_loss = 1e9 # [would be changed in `main.py`]
cur_phase: str = '' # [would be changed in `main.py`]
cur_ep: str = '' # [would be changed in `main.py`]
remain_time: str = '' # [would be changed in `main.py`]
finish_time: str = '' # [would be changed in `main.py`]
first_logging: bool = True
@property
def is_convnext(self):
return 'convnext' in self.model or 'cnx' in self.model
@property
def is_resnet(self):
return 'res' in self.model or 'res' in self.model_alias
def __str__(self):
return re.sub(r"(\[LE-FT\]:\s*)('\s+')?", r'\1', super(Args, self).__str__())
def log_epoch(self):
if not dist.is_local_master():
return
if self.first_logging:
self.first_logging = False
with open(self.log_txt_name, 'w') as fp:
json.dump({
'name': self.exp_name, 'cmd': self.cmd, 'commit_id': self.commit_id,
'model': self.model, 'opt': self.opt,
}, fp)
print('', end='\n', file=fp)
        with open(self.log_txt_name, 'a') as fp:
            json.dump({
                'cur': self.cur_phase, 'cur_ep': self.cur_ep,
                'last_L': self.last_loss,
                'rema': self.remain_time, 'fini': self.finish_time,
            }, fp)
            print('', end='\n', file=fp)  # one JSON record per line, matching the header record above
def init_dist_and_get_args():
from utils import misc
from models import model_alias_to_fullname, model_fullname_to_alias
# initialize
args = Args(explicit_bool=True).parse_args()
misc.init_distributed_environ(exp_dir=args.exp_dir)
# update args
args.cmd = ' '.join(sys.argv[1:])
    args.commit_id = os.popen('git rev-parse HEAD').read().strip()
    args.commit_msg = os.popen('git log -1').read().strip().splitlines()[-1].strip()
if args.model in model_alias_to_fullname.keys():
args.model = model_alias_to_fullname[args.model]
args.model_alias = model_fullname_to_alias[args.model]
args.device = dist.get_device()
args.batch_size = args.bs // dist.get_world_size()
args.glb_batch_size = args.batch_size * dist.get_world_size()
if args.is_resnet:
args.opt = args.opt or 'lamb'
args.ada = args.ada or 0.95
if args.is_convnext:
args.opt = args.opt or 'lamb'
args.ada = args.ada or 0.999
args.en_de_norm = 'ln'
args.opt = args.opt.lower()
args.lr = args.base_lr * args.glb_batch_size / 256
args.wde = args.wde or args.wd
if args.mask2 < 0:
args.mask2 = args.mask
args.mask, args.mask2 = min(args.mask, args.mask2), max(args.mask, args.mask2)
if args.py <= 0:
args.py = 1
args.hea = list(map(int, args.hea.split('_')))
args.log_txt_name = os.path.join(args.exp_dir, 'log.txt')
return args
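# --- illustrative note (not part of the original file) ---
# The peak learning rate follows the linear scaling rule lr = base_lr * glb_batch_size / 256:
# with the defaults base_lr=2e-4 and bs=4096 this resolves to 2e-4 * 4096 / 256 = 3.2e-3,
# while each rank processes batch_size = bs // world_size samples per step.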

@ -0,0 +1,273 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import functools
import os
import subprocess
import sys
import time
from collections import defaultdict, deque
from typing import Iterator
import numpy as np
import pytz
import torch
import torch.distributed as tdist
import dist
os_system = functools.partial(subprocess.call, shell=True)
os_system_get_stdout = lambda cmd: subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8')
def os_system_get_stdout_stderr(cmd):
sp = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
return sp.stdout.decode('utf-8'), sp.stderr.decode('utf-8')
def is_pow2n(x):
return x > 0 and (x & (x - 1) == 0)
def time_str():
return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
def init_distributed_environ(exp_dir):
dist.initialize()
dist.barrier()
import torch.backends.cudnn as cudnn
cudnn.benchmark = True
cudnn.deterministic = False
_set_print_only_on_master_proc(is_master=dist.is_local_master())
if dist.is_local_master() and len(exp_dir):
sys.stdout, sys.stderr = _SyncPrintToFile(exp_dir, stdout=True), _SyncPrintToFile(exp_dir, stdout=False)
def save_checkpoint(fname, args, epoch, performance_desc, model_without_ddp_state, optimizer_state):
checkpoint_path = os.path.join(args.exp_dir, fname)
if dist.is_local_master():
to_save = {
'args': str(args),
'arch': args.model,
'epoch': epoch,
'performance_desc': performance_desc,
'module': model_without_ddp_state,
'optimizer': optimizer_state,
}
torch.save(to_save, checkpoint_path)
dist.barrier()
def load_checkpoint(fname, model_without_ddp, optimizer):
print(f'[try to resume from file `{fname}`]')
checkpoint = torch.load(fname, map_location='cpu')
next_ep, performance_desc = checkpoint['epoch'] + 1, checkpoint['performance_desc']
missing, unexpected = model_without_ddp.load_state_dict(checkpoint['module'], strict=False)
print(f'[load_checkpoint] missing_keys={missing}')
print(f'[load_checkpoint] unexpected_keys={unexpected}')
print(f'[load_checkpoint] next_ep={next_ep}, performance_desc={performance_desc}')
if 'optimizer' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer'])
return next_ep, performance_desc
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
tdist.barrier()
tdist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def time_preds(self, counts):
remain_secs = counts * self.median
remain_time = datetime.timedelta(seconds=round(remain_secs))
finish_time = time.strftime("%m-%d %H:%M", time.localtime(time.time() + remain_secs))
return remain_secs, str(remain_time), finish_time
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if v is None:
continue
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, max_iters, itrt, print_freq, header=None):
print_iters = set(np.linspace(0, max_iters-1, print_freq, dtype=int).tolist())
if not header:
header = ''
start_time = time.time()
end = time.time()
self.iter_time = SmoothedValue(fmt='{avg:.4f}')
self.data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(max_iters))) + 'd'
log_msg = [
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}'
]
log_msg = self.delimiter.join(log_msg)
if isinstance(itrt, Iterator) and not hasattr(itrt, 'preload') and not hasattr(itrt, 'set_epoch'):
for i in range(max_iters):
obj = next(itrt)
self.data_time.update(time.time() - end)
yield obj
self.iter_time.update(time.time() - end)
if i in print_iters:
eta_seconds = self.iter_time.global_avg * (max_iters - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
print(log_msg.format(
i, max_iters, eta=eta_string,
meters=str(self),
time=str(self.iter_time), data=str(self.data_time)))
end = time.time()
else:
for i, obj in enumerate(itrt):
self.data_time.update(time.time() - end)
yield obj
self.iter_time.update(time.time() - end)
if i in print_iters:
eta_seconds = self.iter_time.global_avg * (max_iters - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
print(log_msg.format(
i, max_iters, eta=eta_string,
meters=str(self),
time=str(self.iter_time), data=str(self.data_time)))
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {} ({:.3f} s / it)'.format(
header, total_time_str, total_time / max_iters))
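# --- illustrative sketch (not part of the original file) ---
def _demo_metric_logger():
    # MetricLogger.log_every wraps an iterable, times each step, and prints the
    # eta plus all smoothed meters at `print_freq` evenly spaced iterations.
    me = MetricLogger(delimiter='  ')
    data = [torch.rand(4) for _ in range(50)]
    for x in me.log_every(max_iters=len(data), itrt=iter(data), print_freq=5, header='[demo]'):
        me.update(loss=float(x.mean()))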
def _set_print_only_on_master_proc(is_master):
import builtins as __builtin__
builtin_print = __builtin__.print
def prt(msg, *args, **kwargs):
force = kwargs.pop('force', False)
clean = kwargs.pop('clean', False)
deeper = kwargs.pop('deeper', False)
if is_master or force:
if not clean:
f_back = sys._getframe().f_back
if deeper and f_back.f_back is not None:
f_back = f_back.f_back
file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
msg = f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})=> {msg}'
builtin_print(msg, *args, **kwargs)
__builtin__.print = prt
class _SyncPrintToFile(object):
def __init__(self, exp_dir, stdout=True):
self.terminal = sys.stdout if stdout else sys.stderr
fname = os.path.join(exp_dir, 'stdout.txt' if stdout else 'stderr.txt')
self.log = open(fname, 'w')
self.log.flush()
def write(self, message):
self.terminal.write(message)
self.log.write(message)
self.log.flush()
def flush(self):
self.terminal.flush()
self.log.flush()

@ -0,0 +1,160 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# This file is basically a copy to: https://github.com/rwightman/pytorch-image-models/blob/v0.5.4/timm/optim/lamb.py
""" PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb
This optimizer code was adapted from the following (starting with latest)
* https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py
* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
* https://github.com/cybertronai/pytorch-lamb
Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is
similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX.
In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU.
Original copyrights for above sources are below.
Modifications Copyright 2021 Ross Wightman
"""
import math
import torch
from torch.optim.optimizer import Optimizer
class TimmLAMB(Optimizer):
"""Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB
reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its norm. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        grad_averaging (bool, optional): whether to apply (1-beta1) to grad when
            calculating running averages of gradient. (default: True)
max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0)
trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
        always_adapt (boolean, optional): apply the layer-wise adaptive learning rate
            even when the group's weight decay is 0.0 (default: False)
.. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
https://arxiv.org/abs/1904.00962
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(
self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6,
weight_decay=0.01, grad_averaging=True, max_grad_norm=2.0, trust_clip=False, always_adapt=False):
defaults = dict(
lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay,
grad_averaging=grad_averaging, max_grad_norm=max_grad_norm,
trust_clip=trust_clip, always_adapt=always_adapt)
super().__init__(params, defaults)
print(f'[lamb1] max_grad_norm={max_grad_norm}')
self.global_grad_norm = 0
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
device = self.param_groups[0]['params'][0].device
one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly
global_grad_norm = torch.zeros(1, device=device)
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad
if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
global_grad_norm.add_(grad.pow(2).sum())
global_grad_norm = torch.sqrt(global_grad_norm)
self.global_grad_norm = global_grad_norm.item()
max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device)
clip_global_grad_norm = 1 / torch.where(
global_grad_norm > max_grad_norm,
global_grad_norm / max_grad_norm,
one_tensor)
for group in self.param_groups:
bias_correction = 1 if group['bias_correction'] else 0
beta1, beta2 = group['betas']
grad_averaging = 1 if group['grad_averaging'] else 0
beta3 = 1 - beta1 if grad_averaging else 1.0
            # assume the same step across the group for now to simplify things;
            # a per-parameter step can easily be supported by making it a tensor, or by passing a list into the kernel
if 'step' in group:
group['step'] += 1
else:
group['step'] = 1
if bias_correction:
bias_correction1 = 1 - beta1 ** group['step']
bias_correction2 = 1 - beta2 ** group['step']
else:
bias_correction1, bias_correction2 = 1.0, 1.0
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.mul_(clip_global_grad_norm)
state = self.state[p]
# State initialization
if len(state) == 0:
                    # Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # v_t
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
update = (exp_avg / bias_correction1).div_(denom)
weight_decay = group['weight_decay']
if weight_decay != 0:
update.add_(p, alpha=weight_decay)
if weight_decay != 0 or group['always_adapt']:
# Layer-wise LR adaptation. By default, skip adaptation on parameters that are
# excluded from weight decay, unless always_adapt == True, then always enabled.
w_norm = p.norm(2.0)
g_norm = update.norm(2.0)
# FIXME nested where required since logical and/or not working in PT XLA
trust_ratio = torch.where(
w_norm > 0,
torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
one_tensor,
)
if group['trust_clip']:
# LAMBC trust clipping, upper bound fixed at one
trust_ratio = torch.minimum(trust_ratio, one_tensor)
update.mul_(trust_ratio)
p.add_(update, alpha=-group['lr'])
return loss
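# --- illustrative sketch (not part of the original file) ---
def _demo_lamb_step():
    # One optimization step on a toy model: LAMB first rescales gradients so the
    # global grad norm is at most max_grad_norm, then multiplies each layer's
    # Adam-style update by its trust ratio ||w|| / ||update|| before applying it.
    model = torch.nn.Linear(4, 4)
    opt = TimmLAMB(model.parameters(), lr=1e-3, weight_decay=0.01)
    loss = model(torch.rand(2, 4)).pow(2).mean()
    loss.backward()
    opt.step()
    print(f'global grad norm (before clipping): {opt.global_grad_norm:.4f}')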