commit 6fda3bf645
18 changed files with 2876 additions and 0 deletions
@@ -0,0 +1,7 @@
*.swp
**/__pycache__/**
.idea/*
ckpt/
*.pth
*.log
*.txt
@@ -0,0 +1,395 @@
Attribution 4.0 International

=======================================================================

Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.

Using Creative Commons Public Licenses

Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.

     Considerations for licensors: Our public licenses are
     intended for use by those authorized to give the public
     permission to use material in ways otherwise restricted by
     copyright and certain other rights. Our licenses are
     irrevocable. Licensors should read and understand the terms
     and conditions of the license they choose before applying it.
     Licensors should also secure all rights necessary before
     applying our licenses so that the public can reuse the
     material as expected. Licensors should clearly mark any
     material not subject to the license. This includes other CC-
     licensed material, or material used under an exception or
     limitation to copyright. More considerations for licensors:
     wiki.creativecommons.org/Considerations_for_licensors

     Considerations for the public: By using one of our public
     licenses, a licensor grants the public permission to use the
     licensed material under specified terms and conditions. If
     the licensor's permission is not necessary for any reason--for
     example, because of any applicable exception or limitation to
     copyright--then that use is not regulated by the license. Our
     licenses grant only permissions under copyright and certain
     other rights that a licensor has authority to grant. Use of
     the licensed material may still be restricted for other
     reasons, including because others have copyright or other
     rights in the material. A licensor may make special requests,
     such as asking that all changes be marked or described.
     Although not required by our licenses, you are encouraged to
     respect those requests where reasonable. More considerations
     for the public:
     wiki.creativecommons.org/Considerations_for_licensees

=======================================================================

Creative Commons Attribution 4.0 International Public License

By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution 4.0 International Public License ("Public License"). To the
extent this Public License may be interpreted as a contract, You are
granted the Licensed Rights in consideration of Your acceptance of
these terms and conditions, and the Licensor grants You such rights in
consideration of benefits the Licensor receives from making the
Licensed Material available under these terms and conditions.


Section 1 -- Definitions.

  a. Adapted Material means material subject to Copyright and Similar
     Rights that is derived from or based upon the Licensed Material
     and in which the Licensed Material is translated, altered,
     arranged, transformed, or otherwise modified in a manner requiring
     permission under the Copyright and Similar Rights held by the
     Licensor. For purposes of this Public License, where the Licensed
     Material is a musical work, performance, or sound recording,
     Adapted Material is always produced where the Licensed Material is
     synched in timed relation with a moving image.

  b. Adapter's License means the license You apply to Your Copyright
     and Similar Rights in Your contributions to Adapted Material in
     accordance with the terms and conditions of this Public License.

  c. Copyright and Similar Rights means copyright and/or similar rights
     closely related to copyright including, without limitation,
     performance, broadcast, sound recording, and Sui Generis Database
     Rights, without regard to how the rights are labeled or
     categorized. For purposes of this Public License, the rights
     specified in Section 2(b)(1)-(2) are not Copyright and Similar
     Rights.

  d. Effective Technological Measures means those measures that, in the
     absence of proper authority, may not be circumvented under laws
     fulfilling obligations under Article 11 of the WIPO Copyright
     Treaty adopted on December 20, 1996, and/or similar international
     agreements.

  e. Exceptions and Limitations means fair use, fair dealing, and/or
     any other exception or limitation to Copyright and Similar Rights
     that applies to Your use of the Licensed Material.

  f. Licensed Material means the artistic or literary work, database,
     or other material to which the Licensor applied this Public
     License.

  g. Licensed Rights means the rights granted to You subject to the
     terms and conditions of this Public License, which are limited to
     all Copyright and Similar Rights that apply to Your use of the
     Licensed Material and that the Licensor has authority to license.

  h. Licensor means the individual(s) or entity(ies) granting rights
     under this Public License.

  i. Share means to provide material to the public by any means or
     process that requires permission under the Licensed Rights, such
     as reproduction, public display, public performance, distribution,
     dissemination, communication, or importation, and to make material
     available to the public including in ways that members of the
     public may access the material from a place and at a time
     individually chosen by them.

  j. Sui Generis Database Rights means rights other than copyright
     resulting from Directive 96/9/EC of the European Parliament and of
     the Council of 11 March 1996 on the legal protection of databases,
     as amended and/or succeeded, as well as other essentially
     equivalent rights anywhere in the world.

  k. You means the individual or entity exercising the Licensed Rights
     under this Public License. Your has a corresponding meaning.


Section 2 -- Scope.

  a. License grant.

       1. Subject to the terms and conditions of this Public License,
          the Licensor hereby grants You a worldwide, royalty-free,
          non-sublicensable, non-exclusive, irrevocable license to
          exercise the Licensed Rights in the Licensed Material to:

            a. reproduce and Share the Licensed Material, in whole or
               in part; and

            b. produce, reproduce, and Share Adapted Material.

       2. Exceptions and Limitations. For the avoidance of doubt, where
          Exceptions and Limitations apply to Your use, this Public
          License does not apply, and You do not need to comply with
          its terms and conditions.

       3. Term. The term of this Public License is specified in Section
          6(a).

       4. Media and formats; technical modifications allowed. The
          Licensor authorizes You to exercise the Licensed Rights in
          all media and formats whether now known or hereafter created,
          and to make technical modifications necessary to do so. The
          Licensor waives and/or agrees not to assert any right or
          authority to forbid You from making technical modifications
          necessary to exercise the Licensed Rights, including
          technical modifications necessary to circumvent Effective
          Technological Measures. For purposes of this Public License,
          simply making modifications authorized by this Section 2(a)
          (4) never produces Adapted Material.

       5. Downstream recipients.

            a. Offer from the Licensor -- Licensed Material. Every
               recipient of the Licensed Material automatically
               receives an offer from the Licensor to exercise the
               Licensed Rights under the terms and conditions of this
               Public License.

            b. No downstream restrictions. You may not offer or impose
               any additional or different terms or conditions on, or
               apply any Effective Technological Measures to, the
               Licensed Material if doing so restricts exercise of the
               Licensed Rights by any recipient of the Licensed
               Material.

       6. No endorsement. Nothing in this Public License constitutes or
          may be construed as permission to assert or imply that You
          are, or that Your use of the Licensed Material is, connected
          with, or sponsored, endorsed, or granted official status by,
          the Licensor or others designated to receive attribution as
          provided in Section 3(a)(1)(A)(i).

  b. Other rights.

       1. Moral rights, such as the right of integrity, are not
          licensed under this Public License, nor are publicity,
          privacy, and/or other similar personality rights; however, to
          the extent possible, the Licensor waives and/or agrees not to
          assert any such rights held by the Licensor to the limited
          extent necessary to allow You to exercise the Licensed
          Rights, but not otherwise.

       2. Patent and trademark rights are not licensed under this
          Public License.

       3. To the extent possible, the Licensor waives any right to
          collect royalties from You for the exercise of the Licensed
          Rights, whether directly or through a collecting society
          under any voluntary or waivable statutory or compulsory
          licensing scheme. In all other cases the Licensor expressly
          reserves any right to collect such royalties.


Section 3 -- License Conditions.

Your exercise of the Licensed Rights is expressly made subject to the
following conditions.

  a. Attribution.

       1. If You Share the Licensed Material (including in modified
          form), You must:

            a. retain the following if it is supplied by the Licensor
               with the Licensed Material:

                 i. identification of the creator(s) of the Licensed
                    Material and any others designated to receive
                    attribution, in any reasonable manner requested by
                    the Licensor (including by pseudonym if
                    designated);

                ii. a copyright notice;

               iii. a notice that refers to this Public License;

                iv. a notice that refers to the disclaimer of
                    warranties;

                 v. a URI or hyperlink to the Licensed Material to the
                    extent reasonably practicable;

            b. indicate if You modified the Licensed Material and
               retain an indication of any previous modifications; and

            c. indicate the Licensed Material is licensed under this
               Public License, and include the text of, or the URI or
               hyperlink to, this Public License.

       2. You may satisfy the conditions in Section 3(a)(1) in any
          reasonable manner based on the medium, means, and context in
          which You Share the Licensed Material. For example, it may be
          reasonable to satisfy the conditions by providing a URI or
          hyperlink to a resource that includes the required
          information.

       3. If requested by the Licensor, You must remove any of the
          information required by Section 3(a)(1)(A) to the extent
          reasonably practicable.

       4. If You Share Adapted Material You produce, the Adapter's
          License You apply must not prevent recipients of the Adapted
          Material from complying with this Public License.


Section 4 -- Sui Generis Database Rights.

Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:

  a. for the avoidance of doubt, Section 2(a)(1) grants You the right
     to extract, reuse, reproduce, and Share all or a substantial
     portion of the contents of the database;

  b. if You include all or a substantial portion of the database
     contents in a database in which You have Sui Generis Database
     Rights, then the database in which You have Sui Generis Database
     Rights (but not its individual contents) is Adapted Material; and

  c. You must comply with the conditions in Section 3(a) if You Share
     all or a substantial portion of the contents of the database.

For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.


Section 5 -- Disclaimer of Warranties and Limitation of Liability.

  a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
     EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
     AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
     ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
     IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
     WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
     PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
     ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
     KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
     ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.

  b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
     TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
     NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
     INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
     COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
     USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
     ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
     DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
     IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.

  c. The disclaimer of warranties and limitation of liability provided
     above shall be interpreted in a manner that, to the extent
     possible, most closely approximates an absolute disclaimer and
     waiver of all liability.


Section 6 -- Term and Termination.

  a. This Public License applies for the term of the Copyright and
     Similar Rights licensed here. However, if You fail to comply with
     this Public License, then Your rights under this Public License
     terminate automatically.

  b. Where Your right to use the Licensed Material has terminated under
     Section 6(a), it reinstates:

       1. automatically as of the date the violation is cured, provided
          it is cured within 30 days of Your discovery of the
          violation; or

       2. upon express reinstatement by the Licensor.

     For the avoidance of doubt, this Section 6(b) does not affect any
     right the Licensor may have to seek remedies for Your violations
     of this Public License.

  c. For the avoidance of doubt, the Licensor may also offer the
     Licensed Material under separate terms or conditions or stop
     distributing the Licensed Material at any time; however, doing so
     will not terminate this Public License.

  d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
     License.


Section 7 -- Other Terms and Conditions.

  a. The Licensor shall not be bound by any additional or different
     terms or conditions communicated by You unless expressly agreed.

  b. Any arrangements, understandings, or agreements regarding the
     Licensed Material not stated herein are separate from and
     independent of the terms and conditions of this Public License.


Section 8 -- Interpretation.

  a. For the avoidance of doubt, this Public License does not, and
     shall not be interpreted to, reduce, limit, restrict, or impose
     conditions on any use of the Licensed Material that could lawfully
     be made without permission under this Public License.

  b. To the extent possible, if any provision of this Public License is
     deemed unenforceable, it shall be automatically reformed to the
     minimum extent necessary to make it enforceable. If the provision
     cannot be reformed, it shall be severed from this Public License
     without affecting the enforceability of the remaining terms and
     conditions.

  c. No term or condition of this Public License will be waived and no
     failure to comply consented to unless expressly agreed to by the
     Licensor.

  d. Nothing in this Public License constitutes or may be interpreted
     as a limitation upon, or waiver of, any privileges and immunities
     that apply to the Licensor or You, including from the legal
     processes of any jurisdiction or authority.


=======================================================================

Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the “Licensor.” The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.

Creative Commons may be contacted at creativecommons.org.
@@ -0,0 +1,112 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

from timm.models.layers import trunc_normal_
import torch.nn as nn

from utils.misc import is_pow2n

_BN = None  # set to nn.BatchNorm2d or nn.SyncBatchNorm by LightDecoder.__init__ before any block is built


class UNetBlock2x(nn.Module):
    def __init__(self, cin, cout, cmid, last_act=True):
        super().__init__()
        if cmid == 0:
            c_mid = cin
        elif cmid == 1:
            c_mid = (cin + cout) // 2
        else:  # other values used to fall through unhandled (NameError); treat them as an explicit mid-channel width
            c_mid = cmid
        
        self.b = nn.Sequential(
            nn.Conv2d(cin, c_mid, 3, 1, 1, bias=False), _BN(c_mid), nn.ReLU6(inplace=True),
            nn.Conv2d(c_mid, cout, 3, 1, 1, bias=False), _BN(cout), (nn.ReLU6(inplace=True) if last_act else nn.Identity()),
        )
    
    def forward(self, x):
        return self.b(x)


class DecoderConv(nn.Module):
    def __init__(self, cin, cout, double, heavy, cmid):
        super().__init__()
        self.up = nn.ConvTranspose2d(cin, cin, kernel_size=4 if double else 2, stride=2, padding=1 if double else 0, bias=True)
        ls = [UNetBlock2x(cin, (cin if i != heavy[1]-1 else cout), cmid=cmid, last_act=i != heavy[1]-1) for i in range(heavy[1])]
        self.conv = nn.Sequential(*ls)
    
    def forward(self, x):
        x = self.up(x)
        return self.conv(x)


class LightDecoder(nn.Module):
    def __init__(self, decoder_fea_dim, upsample_ratio, double=False, heavy=None, cmid=0, sbn=False):
        global _BN
        _BN = nn.SyncBatchNorm if sbn else nn.BatchNorm2d
        super().__init__()
        self.fea_dim = decoder_fea_dim
        if heavy is None:
            heavy = [0, 1]
        heavy[1] = max(1, heavy[1])
        self.double_bool = double
        self.heavy = heavy
        self.cmid = cmid
        self.sbn = sbn
        
        assert is_pow2n(upsample_ratio)
        n = round(math.log2(upsample_ratio))
        channels = [self.fea_dim // 2**i for i in range(n+1)]  # halve the channel width at every upsampling stage
        self.dec = nn.ModuleList([
            DecoderConv(cin, cout, double, heavy, cmid) for (cin, cout) in zip(channels[:-1], channels[1:])
        ])
        self.proj = nn.Conv2d(channels[-1], 3, kernel_size=1, stride=1, bias=True)
        
        self.initialize()
    
    def forward(self, to_dec):
        # to_dec: deepest-first list of pyramid features; missing levels are None
        x = 0
        for i, d in enumerate(self.dec):
            if i < len(to_dec) and to_dec[i] is not None:
                x = x + to_dec[i]
            x = self.dec[i](x)
        return self.proj(x)
    
    def num_para(self):
        tot = sum(p.numel() for p in self.parameters())
        
        para1 = para2 = 0
        for m in self.dec.modules():
            if isinstance(m, nn.ConvTranspose2d):
                para1 += sum(p.numel() for p in m.parameters())
            elif isinstance(m, nn.Conv2d):
                para2 += sum(p.numel() for p in m.parameters())
        return f'#para: {tot/1e6:.2f} (dconv={para1/1e6:.2f}, conv={para2/1e6:.2f}, ot={(tot-para1-para2)/1e6:.2f})'
    
    def extra_repr(self) -> str:
        return f'fea_dim={self.fea_dim}, dbl={self.double_bool}, heavy={self.heavy}, cmid={self.cmid}, sbn={self.sbn}'
    
    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Embedding):
                trunc_normal_(m.weight, std=.02)
                if m.padding_idx is not None:
                    m.weight.data[m.padding_idx].zero_()
            elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.SyncBatchNorm)):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
            elif isinstance(m, nn.Conv2d):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose2d):  # nn.Conv2d is caught above, so only transposed convs reach here
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.)
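A minimal shape check makes the decoder geometry concrete. The sketch below is hypothetical usage, assuming the repo root is importable (utils.misc.is_pow2n, imported above, lives in one of the commit's other files): with decoder_fea_dim=512 and upsample_ratio=32, the decoder stacks log2(32)=5 DecoderConv stages, halving the channel width at each one, then projects to 3 RGB channels.

import torch
from decoder import LightDecoder

dec = LightDecoder(decoder_fea_dim=512, upsample_ratio=32)  # stages: 512->256->128->64->32->16 channels
feats = [torch.randn(2, 512, 7, 7)]    # deepest feature map only; shallower pyramid levels omitted
out = dec(feats)                       # every DecoderConv doubles H and W
assert out.shape == (2, 3, 224, 224)   # 7 * 2**5 = 224, projected to RGB by the 1x1 conv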
@@ -0,0 +1,142 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import functools
import os
from typing import List
from typing import Union

import torch
import torch.distributed as tdist
import torch.multiprocessing as mp

__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cpu'
__initialized = False


def initialized():
    return __initialized


def initialize(backend='nccl'):
    # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    global_rank, num_gpus = int(os.environ.get('RANK', 'error')), torch.cuda.device_count()  # deliberately fails fast if the launcher did not set RANK
    local_rank = global_rank % num_gpus
    torch.cuda.set_device(local_rank)
    tdist.init_process_group(backend=backend)  # no need to pass init_method='env://'
    
    global __rank, __local_rank, __world_size, __device, __initialized
    __local_rank = local_rank
    __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
    __device = torch.empty(1).cuda().device
    __initialized = True
    
    assert tdist.is_initialized(), 'torch.distributed is not initialized!'


def get_rank():
    return __rank


def get_local_rank():
    return __local_rank


def get_world_size():
    return __world_size


def get_device():
    return __device


def is_master():
    return __rank == 0


def is_local_master():
    return __local_rank == 0


def parallelize(net, syncbn=False):
    if syncbn:
        net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = net.cuda()
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
    return net


def new_group(ranks: List[int]):
    return tdist.new_group(ranks=ranks)


def barrier():
    tdist.barrier()


def allreduce(t: torch.Tensor) -> None:
    if not t.is_cuda:
        cu = t.detach().cuda()
        tdist.all_reduce(cu)
        t.copy_(cu.cpu())
    else:
        tdist.all_reduce(t)


def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    if not t.is_cuda:
        t = t.cuda()
    ls = [torch.empty_like(t) for _ in range(__world_size)]
    tdist.all_gather(ls, t)
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls


def broadcast(t: torch.Tensor, src_rank) -> None:
    if not t.is_cuda:
        cu = t.detach().cuda()
        tdist.broadcast(cu, src=src_rank)
        t.copy_(cu.cpu())
    else:
        tdist.broadcast(t, src=src_rank)


def dist_fmt_vals(val, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:
    # each rank writes its value into its own slot; the sum-allreduce then
    # gives every rank the full per-rank vector
    ts = torch.zeros(__world_size)
    ts[__rank] = val
    allreduce(ts)
    if fmt is None:
        return ts
    return [fmt % v for v in ts.cpu().numpy().tolist()]


def master_only(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        force = kwargs.pop('force', False)
        if force or is_master():
            ret = func(*args, **kwargs)
        else:
            ret = None
        barrier()
        return ret
    return wrapper


def local_master_only(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        force = kwargs.pop('force', False)
        if force or is_local_master():
            ret = func(*args, **kwargs)
        else:
            ret = None
        barrier()
        return ret
    return wrapper
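To see how these helpers compose at runtime, here is a hedged sketch; it assumes the process was started by a launcher that sets RANK (e.g. the launcher below) and that CUDA is available, and the ranks and values are illustrative only.

import dist

dist.initialize()                                    # reads RANK, picks a GPU, joins the NCCL group
vals = dist.dist_fmt_vals(1.5 * dist.get_rank(), fmt='%.2f')
if dist.is_master():
    print(vals)                                      # with 4 ranks: ['0.00', '1.50', '3.00', '4.50']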
@@ -0,0 +1,207 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from timm.models.layers import DropPath


_cur_active: torch.Tensor = None    # B1ff: a (B, 1, f, f) boolean mask on the final feature grid
def _get_active_ex_or_ii(H, returning_active_ex=True):
    downsample_ratio = H // _cur_active.shape[-1]
    active_ex = _cur_active.repeat_interleave(downsample_ratio, 2).repeat_interleave(downsample_ratio, 3)
    return active_ex if returning_active_ex else active_ex.squeeze(1).nonzero(as_tuple=True)    # ii: (bi, hi, wi), the indices of all active pixels


def sp_conv_forward(self, x: torch.Tensor):
    x = super(type(self), self).forward(x)
    x *= _get_active_ex_or_ii(H=x.shape[2], returning_active_ex=True)    # (BCHW) *= (B1HW): re-zero the masked positions
    return x


def sp_bn_forward(self, x: torch.Tensor):
    ii = _get_active_ex_or_ii(H=x.shape[2], returning_active_ex=False)
    
    bhwc = x.permute(0, 2, 3, 1)
    nc = bhwc[ii]                                  # gather only the active pixels into an (N, C) batch
    nc = super(type(self), self).forward(nc)       # BN1d forward over active pixels only
    
    bhwc = torch.zeros_like(bhwc)
    bhwc[ii] = nc                                  # scatter the normalized pixels back; masked positions stay zero
    return bhwc.permute(0, 3, 1, 2)                # back to BCHW


class SparseConv2d(nn.Conv2d):
    forward = sp_conv_forward


class SparseMaxPooling(nn.MaxPool2d):
    forward = sp_conv_forward


class SparseAvgPooling(nn.AvgPool2d):
    forward = sp_conv_forward


class SparseBatchNorm2d(nn.BatchNorm1d):
    forward = sp_bn_forward


class SparseSyncBatchNorm2d(nn.SyncBatchNorm):
    forward = sp_bn_forward


class SparseConvNeXtLayerNorm(nn.LayerNorm):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    """
    
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", sparse=True):
        if data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        super().__init__(normalized_shape, eps, elementwise_affine=True)
        self.data_format = data_format
        self.sparse = sparse
    
    def forward(self, x):
        if x.ndim == 4:  # 4D: BHWC if channels_last, BCHW if channels_first
            if self.data_format == "channels_last":
                if self.sparse:
                    ii = _get_active_ex_or_ii(H=x.shape[1], returning_active_ex=False)
                    nc = x[ii]
                    nc = super(SparseConvNeXtLayerNorm, self).forward(nc)
                    
                    x = torch.zeros_like(x)
                    x[ii] = nc
                    return x
                else:
                    return super(SparseConvNeXtLayerNorm, self).forward(x)
            else:  # channels_first
                if self.sparse:
                    ii = _get_active_ex_or_ii(H=x.shape[2], returning_active_ex=False)
                    bhwc = x.permute(0, 2, 3, 1)
                    nc = bhwc[ii]
                    nc = super(SparseConvNeXtLayerNorm, self).forward(nc)
                    
                    x = torch.zeros_like(bhwc)
                    x[ii] = nc
                    return x.permute(0, 3, 1, 2)
                else:
                    u = x.mean(1, keepdim=True)
                    s = (x - u).pow(2).mean(1, keepdim=True)
                    x = (x - u) / torch.sqrt(s + self.eps)
                    x = self.weight[:, None, None] * x + self.bias[:, None, None]
                    return x
        else:  # BLC or BC
            if self.sparse:
                raise NotImplementedError
            else:
                return super(SparseConvNeXtLayerNorm, self).forward(x)
    
    def __repr__(self):
        return super(SparseConvNeXtLayerNorm, self).__repr__()[:-1] + f', ch={self.data_format.split("_")[-1]}, sp={self.sparse})'


class SparseConvNeXtBlock(nn.Module):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch
    
    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """
    
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, sparse=True, ks=7):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=ks, padding=ks//2, groups=dim)  # depthwise conv
        self.norm = SparseConvNeXtLayerNorm(dim, eps=1e-6, sparse=sparse)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.sparse = sparse
    
    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        
        if self.sparse:
            x *= _get_active_ex_or_ii(H=x.shape[2], returning_active_ex=True)
        
        x = input + self.drop_path(x)
        return x
    
    def __repr__(self):
        return super(SparseConvNeXtBlock, self).__repr__()[:-1] + f', sp={self.sparse})'


class SparseEncoder(nn.Module):
    def __init__(self, cnn, input_size, downsample_ratio, encoder_fea_dim, verbose=False, sbn=False):
        super(SparseEncoder, self).__init__()
        self.sp_cnn = SparseEncoder.dense_model_to_sparse(m=cnn, verbose=verbose, sbn=sbn)
        self.input_size, self.downsample_ratio, self.fea_dim = input_size, downsample_ratio, encoder_fea_dim
    
    @staticmethod
    def dense_model_to_sparse(m: nn.Module, verbose=False, sbn=False):
        oup = m
        if isinstance(m, nn.Conv2d):
            m: nn.Conv2d
            bias = m.bias is not None
            oup = SparseConv2d(
                m.in_channels, m.out_channels,
                kernel_size=m.kernel_size, stride=m.stride, padding=m.padding,
                dilation=m.dilation, groups=m.groups, bias=bias, padding_mode=m.padding_mode,
            )
            oup.weight.data.copy_(m.weight.data)
            if bias:
                oup.bias.data.copy_(m.bias.data)
        elif isinstance(m, nn.MaxPool2d):
            m: nn.MaxPool2d
            oup = SparseMaxPooling(m.kernel_size, stride=m.stride, padding=m.padding, dilation=m.dilation, return_indices=m.return_indices, ceil_mode=m.ceil_mode)
        elif isinstance(m, nn.AvgPool2d):
            m: nn.AvgPool2d
            oup = SparseAvgPooling(m.kernel_size, m.stride, m.padding, ceil_mode=m.ceil_mode, count_include_pad=m.count_include_pad, divisor_override=m.divisor_override)
        elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            m: nn.BatchNorm2d
            oup = (SparseSyncBatchNorm2d if sbn else SparseBatchNorm2d)(m.weight.shape[0], eps=m.eps, momentum=m.momentum, affine=m.affine, track_running_stats=m.track_running_stats)
            oup.weight.data.copy_(m.weight.data)
            oup.bias.data.copy_(m.bias.data)
            oup.running_mean.data.copy_(m.running_mean.data)
            oup.running_var.data.copy_(m.running_var.data)
            oup.num_batches_tracked.data.copy_(m.num_batches_tracked.data)
            if hasattr(m, "qconfig"):
                oup.qconfig = m.qconfig
        elif isinstance(m, nn.LayerNorm) and not isinstance(m, SparseConvNeXtLayerNorm):
            m: nn.LayerNorm
            oup = SparseConvNeXtLayerNorm(m.weight.shape[0], eps=m.eps)
            oup.weight.data.copy_(m.weight.data)
            oup.bias.data.copy_(m.bias.data)
        elif isinstance(m, (nn.Conv1d,)):
            raise NotImplementedError
        
        for name, child in m.named_children():
            oup.add_module(name, SparseEncoder.dense_model_to_sparse(child, verbose=verbose, sbn=sbn))
        del m
        return oup
    
    def forward(self, x, pyramid):
        return self.sp_cnn(x, pyramid=pyramid)
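The `_cur_active` mask is the heart of the sparse computation above: a (B, 1, f, f) boolean map on the final feature grid is nearest-neighbor upsampled to whatever resolution a layer runs at, so masked positions are re-zeroed after every conv/pool and excluded from the normalization statistics. A standalone illustration (the shapes below are arbitrary):

import torch
import encoder

encoder._cur_active = torch.rand(2, 1, 7, 7) > 0.6      # B1ff mask on the final 7x7 grid
active = encoder._get_active_ex_or_ii(H=56, returning_active_ex=True)
assert active.shape == (2, 1, 56, 56)                   # each mask cell expanded to an 8x8 block
bi, hi, wi = encoder._get_active_ex_or_ii(H=56, returning_active_ex=False)
assert bi.numel() == int(active.sum())                  # (bi, hi, wi) index exactly the active pixels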
@@ -0,0 +1,87 @@
#!/usr/bin/python3

# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import functools
import os
import socket
import subprocess
import sys
from typing import List

echo = lambda info: os_system(f'echo "[$(date "+%m-%d-%H:%M:%S")] ({os.path.basename(sys._getframe().f_back.f_code.co_filename)}, line{sys._getframe().f_back.f_lineno})=> {info}"')
os_system = functools.partial(subprocess.call, shell=True)
os_system_get_stdout = lambda cmd: subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8')
def os_system_get_stdout_stderr(cmd):
    sp = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    return sp.stdout.decode('utf-8'), sp.stderr.decode('utf-8')


def __find_free_port():  # helper kept for convenience; not used below
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 will cause the OS to find an available port for us
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    # NOTE: there is still a chance the port could be taken by other processes.
    return port


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch Distributed Launcher')
    parser.add_argument('--main_py_relpath', type=str, default='main.py',
                        help='relative path of the main script to be launched')
    
    # distributed environment
    parser.add_argument('--num_nodes', type=int, default=1)
    parser.add_argument('--ngpu_per_node', type=int, default=1)
    parser.add_argument('--node_rank', type=int, default=0,
                        help='node rank, ranging from 0 to num_nodes-1')
    parser.add_argument('--master_address', type=str, default='128.0.0.0',
                        help='master address for distributed communication')
    parser.add_argument('--master_port', type=int, default=30001,
                        help='master port for distributed communication')
    
    # other args
    known_args, other_args = parser.parse_known_args()
    other_args: List[str]
    echo(f'[other_args received by launch.py]: {other_args}')
    
    # normalize the last positional argument: strip spaces around '=' and turn each bare flag into flag=1
    main_args = other_args[-1]
    main_args = '='.join(map(str.strip, main_args.split('=')))
    main_args = main_args.split(' ')
    for i, a in enumerate(main_args):
        if len(a) and '=' not in a:
            main_args[i] = f'{a}=1'
    other_args[-1] = ' '.join(main_args)
    
    echo(f'[final other_args]: {other_args[-1]}')
    
    if known_args.num_nodes > 1:
        os.environ['NPROC_PER_NODE'] = str(known_args.ngpu_per_node)
        cmd = (
            f'python3 -m torch.distributed.launch'
            f' --nproc_per_node={known_args.ngpu_per_node}'
            f' --nnodes={known_args.num_nodes}'
            f' --node_rank={known_args.node_rank}'
            f' --master_addr={known_args.master_address}'
            f' --master_port={known_args.master_port}'
            f' {known_args.main_py_relpath}'
            f' {" ".join(other_args)}'
        )
    else:
        cmd = (
            f'python3 -m torch.distributed.launch'
            f' --nproc_per_node={known_args.ngpu_per_node}'
            f' --master_port={known_args.master_port}'
            f' {known_args.main_py_relpath}'
            f' {" ".join(other_args)}'
        )
    
    exit_code = subprocess.call(cmd, shell=True)
    sys.exit(exit_code)
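The argument normalization in the middle of this script is easy to misread: it rewrites every bare flag inside the last positional argument into flag=1 form, so downstream parsing can assume key=value pairs throughout. Rerunning that exact loop on a made-up example:

main_args = '--sbn --ep=1600 --model=res50'.split(' ')
for i, a in enumerate(main_args):
    if len(a) and '=' not in a:
        main_args[i] = f'{a}=1'
print(' '.join(main_args))   # --sbn=1 --ep=1600 --model=res50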
@@ -0,0 +1,156 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import sys
import time
from functools import partial
from typing import List

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader

import dist
import encoder
from decoder import LightDecoder
from models import build_sparse_encoder
from sampler import DistInfiniteBatchSampler, worker_init_fn
from spark import SparK
from utils import meta, misc, optim
from utils.imagenet import build_imagenet
from utils.lr_control import lr_wd_annealing, get_param_groups


def main_pt():
    args: meta.Args = meta.init_dist_and_get_args()
    print(f'global bs={args.glb_batch_size}, local bs={args.batch_size}')
    print(f'initial args:\n{str(args)}')
    args.log_epoch()
    
    # build data
    print(f'[build data for pre-training] ...\n')
    dataset_train, _ = build_imagenet('pt', args.data_path, args.data_set, args.input_size, eval_crop_pct=None, rrc=args.rrc)
    data_loader_train = DataLoader(
        dataset=dataset_train, num_workers=args.num_workers, pin_memory=True,
        batch_sampler=DistInfiniteBatchSampler(
            dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size, seed=args.seed,
            shuffle=True, filling=True, rank=dist.get_rank(), world_size=dist.get_world_size(),
        ), worker_init_fn=worker_init_fn
    )
    itrt_train = iter(data_loader_train)
    iters_train = len(data_loader_train)
    print(f'[dataloader] gbs={args.glb_batch_size}, lbs={args.batch_size}, iters_train={iters_train}')
    
    # build models (encoder, decoder, and other components)
    enc: encoder.SparseEncoder = build_sparse_encoder(args.model, input_size=args.input_size, sbn=args.sbn, drop_path_rate=args.dp, verbose=False)
    dec = LightDecoder(args.dec_dim, enc.downsample_ratio, double=args.double, heavy=args.hea, cmid=args.cmid, sbn=args.sbn)
    spark = SparK(
        sparse_encoder=enc, dense_decoder=dec, mask_ratio=args.mask, mask_ratio2=args.mask2, uniform=args.uni,
        using_pe=args.pe, pix_norm=args.pn, dense_loss=args.den, loss_l2=args.loss_l2,
        en_de_norm=args.en_de_norm, en_de_lin=args.en_de_lin, sbn=args.sbn, pyramid=args.py,
    )
    print(f'[PT model] model = {spark}\n')
    spark.to(args.device)
    model: DistributedDataParallel = DistributedDataParallel(spark, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
    model_without_ddp: SparK = model.module
    
    # build optimizer and lr_scheduler
    param_groups: List[dict] = get_param_groups(model_without_ddp, nowd_keys={'pos_embed', 'mask_token', 'gamma'}, lr_scale=0)
    opt_clz = {
        'sgd': partial(torch.optim.SGD, momentum=0.9, nesterov=True),
        'adamw': partial(torch.optim.AdamW, betas=(0.9, args.ada)),
        'lamb': partial(optim.TimmLAMB, betas=(0.9, args.ada), max_grad_norm=args.clip),
    }[args.opt]
    optimizer = opt_clz(params=param_groups, lr=args.lr, weight_decay=0.0)
    print(f'[optimizer] optimizer({opt_clz}) = {optimizer}\n')
    
    # try to resume
    next_ep, performance_desc = misc.load_checkpoint(args.resume, model_without_ddp, optimizer) if len(args.resume) else (0, '[no performance_desc]')
    if next_ep >= args.ep:
        # loaded from an already complete checkpoint file
        print(f' [*] [PT already done] Min/Last Recon Loss: {performance_desc}')
    else:
        # perform pre-training
        start_time = time.time()
        min_loss = 1e9
        print(f'[PT start] from ep{next_ep}')
        
        for ep in range(next_ep, args.ep):
            if hasattr(itrt_train, 'set_epoch'):
                itrt_train.set_epoch(ep)
            
            stats, (sec, remain_time, finish_time) = pre_train_one_ep(ep, args, itrt_train, iters_train, model, optimizer)
            last_loss = stats['last_loss']
            min_loss = min(min_loss, last_loss)
            performance_desc = f'{min_loss:.4f} {last_loss:.4f}'
            print(f' [*] [ep{ep}] Min/Last Recon Loss: {performance_desc}, Remain: {remain_time}, Finish: {finish_time}')
            
            args.cur_phase = 'PT'
            args.cur_ep = f'{ep+1}/{args.ep}'
            args.remain_time, args.finish_time = str(remain_time), str(finish_time)
            args.last_loss = last_loss
            args.log_epoch()
            misc.save_checkpoint('ckpt-last.pth', args, ep, performance_desc, model_without_ddp.state_dict(with_config=True), optimizer.state_dict())
        
        # finish pre-training
        print('\n\n')
        print(f' [*] [PT finished] Min/Last Recon Loss: {performance_desc}, Total Cost: {(time.time() - start_time) / 60 / 60:.1f}h')
        print('\n\n')
    
    misc.save_checkpoint('ckpt-final.pth', args, args.ep-1, performance_desc, model_without_ddp.state_dict(with_config=True), optimizer.state_dict())
    args.remain_time, args.finish_time = '-', time.strftime("%m-%d %H:%M", time.localtime(time.time()))
    args.log_epoch()


def pre_train_one_ep(ep, args, itrt_train, iters_train, model: DistributedDataParallel, optimizer):
    model.train()
    me = misc.MetricLogger(delimiter=' ')
    me.add_meter('max_lr', misc.SmoothedValue(window_size=1, fmt='{value:.5f}'))
    header = f'[PT] Epoch: [{ep:3d}/{args.ep}]'
    
    optimizer.zero_grad()
    early_clipping = args.clip > 0 and not hasattr(optimizer, 'global_grad_norm')
    late_clipping = args.clip > 0 and hasattr(optimizer, 'global_grad_norm')  # e.g., an optimizer like TimmLAMB that clips internally and exposes the resulting norm
    if early_clipping:
        params_req_grad = [p for p in model.parameters() if p.requires_grad]
    
    # for every batch do:
    for it, (inp, _) in enumerate(me.log_every(iters_train, itrt_train, 3, header)):
        # adjust lr and wd
        g_it = it + ep*iters_train
        min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(optimizer, args.lr, args.wd, args.wde, g_it, args.wp_ep*iters_train, args.ep*iters_train)
        
        # forward and backward
        inp = inp.to(args.device, non_blocking=True)
        active_ex, rec, loss = model(inp)  # see SparK.forward for this (active_ex, rec, loss) contract
        optimizer.zero_grad()
        loss.backward()
        loss = loss.item()
        if not math.isfinite(loss):
            print(f'[rk{dist.get_rank():02d}] Loss is {loss}, stopping training!', force=True, flush=True)  # 'force' is presumably consumed by a patched print that silences non-master ranks
            sys.exit(-1)
        
        # optimize
        grad_norm = None
        if early_clipping: grad_norm = torch.nn.utils.clip_grad_norm_(params_req_grad, args.clip)
        optimizer.step()
        if late_clipping: grad_norm = optimizer.global_grad_norm
        torch.cuda.synchronize()
        
        # log
        me.update(last_loss=loss)
        me.update(max_lr=max_lr)
        if grad_norm is not None:
            me.update(orig_norm=grad_norm)
    
    me.synchronize_between_processes()
    return {k: meter.global_avg for k, meter in me.meters.items()}, me.iter_time.time_preds((args.ep-1-ep) * (iters_train+10))


if __name__ == '__main__':
    main_pt()
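The opt_clz table above is a small factory pattern: functools.partial pre-binds each optimizer's fixed hyper-parameters, so selection and construction stay uniform afterwards. A standalone sketch of the same idea (the beta value here is illustrative, not the repo's default):

from functools import partial
import torch

opt_clz = {
    'sgd':   partial(torch.optim.SGD, momentum=0.9, nesterov=True),
    'adamw': partial(torch.optim.AdamW, betas=(0.9, 0.95)),
}['adamw']
optimizer = opt_clz(params=[torch.nn.Parameter(torch.zeros(4))], lr=2e-4, weight_decay=0.0)
print(type(optimizer).__name__)   # AdamW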
@@ -0,0 +1,83 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from timm import create_model
from timm.loss import SoftTargetCrossEntropy
from timm.models.layers import drop


from models.convnext import ConvNeXt
from models.resnet import ResNet
_import_for_timm_registration = (ResNet, ConvNeXt)  # these imports are kept alive because importing them registers the models with timm


# log more: patch these classes so their repr shows all non-default fields
def _ex_repr(self):
    return ', '.join(
        f'{k}=' + (f'{v:g}' if isinstance(v, float) else str(v))
        for k, v in vars(self).items()
        if not k.startswith('_') and k != 'training'
        and not isinstance(v, (torch.nn.Module, torch.Tensor))
    )
for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy, drop.DropPath):
    if hasattr(clz, 'extra_repr'):
        clz.extra_repr = _ex_repr
    else:
        clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})'


model_alias_to_fullname = {
    'res50': 'resnet50',
    'res101': 'resnet101',
    'res152': 'resnet152',
    'res200': 'resnet200',
    'cnxS': 'convnext_small',
    'cnxB': 'convnext_base',
    'cnxL': 'convnext_large',
}
model_fullname_to_alias = {v: k for k, v in model_alias_to_fullname.items()}


pre_train_d = {  # name: [default kwargs, #params (M), FLOPs (G), downsample ratio, feature dim]
    'resnet50': [dict(drop_path_rate=0.05), 25.6, 4.1, 32, 2048],
    'resnet101': [dict(drop_path_rate=0.08), 44.5, 7.9, 32, 2048],
    'resnet152': [dict(drop_path_rate=0.10), 60.2, 11.6, 32, 2048],
    'resnet200': [dict(drop_path_rate=0.15), 64.7, 15.1, 32, 2048],
    'convnext_small': [dict(sparse=True, drop_path_rate=0.2), 50.0, 8.7, 32, 768],
    'convnext_base': [dict(sparse=True, drop_path_rate=0.3), 89.0, 15.4, 32, 1024],
    'convnext_large': [dict(sparse=True, drop_path_rate=0.4), 198.0, 34.4, 32, 1536],
}
for v in pre_train_d.values():
    v[0]['pretrained'] = False
    v[0]['num_classes'] = 0
    v[0]['global_pool'] = ''


def build_sparse_encoder(name: str, input_size: int, sbn=False, drop_path_rate=0.0, verbose=False):
    from encoder import SparseEncoder  # imported lazily to avoid import-order issues
    
    kwargs, params, flops, downsample_ratio, fea_dim = pre_train_d[name]
    if drop_path_rate != 0:
        kwargs['drop_path_rate'] = drop_path_rate
    print(f'[sparse_cnn] model kwargs={kwargs}')
    cnn = create_model(name, **kwargs)
    if hasattr(cnn, 'global_pool'):
        if callable(cnn.global_pool):
            cnn.global_pool = torch.nn.Identity()
        elif isinstance(cnn.global_pool, str):
            cnn.global_pool = ''
    
    if not isinstance(downsample_ratio, int) or not isinstance(fea_dim, int):
        # measure them with a dummy forward if the table does not record integers
        with torch.no_grad():
            cnn.eval()
            o = cnn(torch.rand(1, 3, input_size, input_size))
            downsample_ratio = input_size // o.shape[-1]
            fea_dim = o.shape[1]
            cnn.train()
    print(f'[sparse_cnn] downsample_ratio={downsample_ratio}, fea_dim={fea_dim}')
    
    return SparseEncoder(cnn, input_size=input_size, downsample_ratio=downsample_ratio, encoder_fea_dim=fea_dim, verbose=verbose, sbn=sbn)
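A hedged usage sketch of the builder, assuming models/resnet.py from this commit registers 'resnet50' with timm (its import above exists for exactly that side effect). Since pre_train_d already stores integer downsample ratio and feature dim for resnet50, the dummy forward probe is skipped:

from models import build_sparse_encoder

enc = build_sparse_encoder('resnet50', input_size=224)
print(enc.downsample_ratio, enc.fea_dim)   # 32 2048, straight from the pre_train_d table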
@@ -0,0 +1,212 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# This file is basically a copy of: https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py

import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_
from timm.models.registry import register_model

from encoder import SparseConvNeXtBlock, SparseConvNeXtLayerNorm


class ConvNeXt(nn.Module):
    r""" ConvNeXt
    A PyTorch impl of : `A ConvNet for the 2020s` -
    https://arxiv.org/pdf/2201.03545.pdf
    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """
    
    def __init__(self, in_chans=3, num_classes=1000,
                 depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
                 layer_scale_init_value=1e-6, head_init_scale=1., global_pool='avg',
                 sparse=True,
                 ):
        super().__init__()
        
        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            SparseConvNeXtLayerNorm(dims[0], eps=1e-6, data_format="channels_first", sparse=sparse)
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                SparseConvNeXtLayerNorm(dims[i], eps=1e-6, data_format="channels_first", sparse=sparse),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)
        
        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
        self.drop_path_rate = drop_path_rate
        self.layer_scale_init_value = layer_scale_init_value
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[SparseConvNeXtBlock(dim=dims[i], drop_path=dp_rates[cur + j],
                                      layer_scale_init_value=layer_scale_init_value, sparse=sparse) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]
        self.depths = depths
        
        self.apply(self._init_weights)
        if num_classes > 0:
            self.norm = SparseConvNeXtLayerNorm(dims[-1], eps=1e-6, sparse=False)  # final norm layer for LE/FT; should not be sparse
            self.fc = nn.Linear(dims[-1], num_classes)
            # self.fc.weight.data.mul_(head_init_scale)  # todo: perform this outside
            # self.fc.bias.data.mul_(head_init_scale)    # todo: perform this outside
        else:
            self.norm = nn.Identity()
            self.fc = nn.Identity()
        
        self.with_pooling = len(global_pool) > 0
    
    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)
    
    def forward_features(self, x, pyramid: int):  # pyramid: 0, 1, 2, 3, 4
        ls = []
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
            if pyramid:
                ls.append(x)
        
        if pyramid:
            for i in range(len(ls)-pyramid-1, -1, -1):  # drop the shallowest (4 - pyramid) feature maps
                del ls[i]
            return [None] * (4 - pyramid) + ls          # pad the front with None so the list always has length 4
        else:
            if self.with_pooling:
                x = x.mean([-2, -1])  # global average pooling, (N, C, H, W) -> (N, C)
            return x
    
    def forward(self, x, pyramid=0):
        if pyramid == 0:
            x = self.forward_features(x, pyramid=pyramid)
            x = self.fc(self.norm(x))
            return x
        else:
            return self.forward_features(x, pyramid=pyramid)
    
    def get_classifier(self):
        return self.fc
    
    def extra_repr(self):
        return f'drop_path_rate={self.drop_path_rate}, layer_scale_init_value={self.layer_scale_init_value:g}'
    
    def get_layer_id_and_scale_exp(self, para_name: str):
        N = 12 if self.depths[-2] > 9 else 6
        if para_name.startswith("downsample_layers"):
            stage_id = int(para_name.split('.')[1])
            if stage_id == 0:
                layer_id = 0
            elif stage_id == 1 or stage_id == 2:
                layer_id = stage_id + 1
            else:  # stage_id == 3
                layer_id = N
        elif para_name.startswith("stages"):
            stage_id = int(para_name.split('.')[1])
            block_id = int(para_name.split('.')[2])
            if stage_id == 0 or stage_id == 1:
                layer_id = stage_id + 1
            elif stage_id == 2:
                layer_id = 3 + block_id // 3
            else:  # stage_id == 3
                layer_id = N
        else:
            layer_id = N + 1  # after backbone
        
        return layer_id, N + 1 - layer_id


model_urls = {
    "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
    "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
    "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
    "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
    "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
    "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
    "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
    "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
    "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
}


@register_model
def convnext_tiny(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
    if pretrained:
        url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def convnext_small(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
    if pretrained:
        url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def convnext_base(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
    if pretrained:
        url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def convnext_large(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
    if pretrained:
        url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
    if pretrained:
        assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
        url = model_urls['convnext_xlarge_22k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model


if __name__ == '__main__':
    from timm.models import create_model
    
    c = create_model('convnext_small', sparse=False)
    with torch.no_grad():
        x = torch.rand(2, 3, 224, 224)
        print(c(x).shape)
        print([None if f is None else f.shape for f in c(x, pyramid=1)])
        print([None if f is None else f.shape for f in c(x, pyramid=2)])
        print([None if f is None else f.shape for f in c(x, pyramid=3)])
        print([None if f is None else f.shape for f in c(x, pyramid=4)])
@ -0,0 +1,105 @@ |
||||
# Copyright (c) ByteDance, Inc. and its affiliates. |
||||
# All rights reserved. |
||||
# |
||||
# This source code is licensed under the license found in the |
||||
# LICENSE file in the root directory of this source tree. |
||||
|
||||
import math |
||||
|
||||
import torch.nn.functional as F |
||||
from timm.models.resnet import ResNet |
||||
|
||||
|
||||
def forward_features(self, x, pyramid: int): # pyramid: 0, 1, 2, 3, 4 |
||||
x = self.conv1(x) |
||||
x = self.bn1(x) |
||||
x = self.act1(x) |
||||
x = self.maxpool(x) |
||||
|
||||
ls = [] |
||||
x = self.layer1(x) |
||||
if pyramid: ls.append(x) |
||||
x = self.layer2(x) |
||||
if pyramid: ls.append(x) |
||||
x = self.layer3(x) |
||||
if pyramid: ls.append(x) |
||||
x = self.layer4(x) |
||||
if pyramid: ls.append(x) |
||||
|
||||
if pyramid: |
||||
for i in range(len(ls)-pyramid-1, -1, -1): |
||||
del ls[i] |
||||
return [None] * (4 - pyramid) + ls |
||||
else: |
||||
return x |
||||
|
||||
|
||||
def forward(self, x, pyramid=0): |
||||
if pyramid == 0: |
||||
x = self.forward_features(x, pyramid=pyramid) |
||||
x = self.global_pool(x) |
||||
if self.drop_rate: |
||||
x = F.dropout(x, p=float(self.drop_rate), training=self.training) |
||||
x = self.fc(x) |
||||
return x |
||||
else: |
||||
return self.forward_features(x, pyramid=pyramid) |
||||
|
||||
|
||||
def resnets_get_layer_id_and_scale_exp(self, para_name: str): |
||||
# stages: |
||||
# 50 : [3, 4, 6, 3] |
||||
# 101 : [3, 4, 23, 3] |
||||
# 152 : [3, 8, 36, 3] |
||||
# 200 : [3, 24, 36, 3] |
||||
# eca269d: [3, 30, 48, 8] |
||||
|
||||
L2, L3 = len(self.layer2), len(self.layer3) |
||||
if L2 == 4 and L3 == 6: |
||||
blk2, blk3 = 2, 3 |
||||
elif L2 == 4 and L3 == 23: |
||||
blk2, blk3 = 2, 3 |
||||
elif L2 == 8 and L3 == 36: |
||||
blk2, blk3 = 4, 4 |
||||
elif L2 == 24 and L3 == 36: |
||||
blk2, blk3 = 4, 4 |
||||
elif L2 == 30 and L3 == 48: |
||||
blk2, blk3 = 5, 6 |
||||
else: |
||||
raise NotImplementedError |
||||
|
||||
N2, N3 = math.ceil(L2 / blk2 - 1e-5), math.ceil(L3 / blk3 - 1e-5) |
||||
N = 2 + N2 + N3 |
||||
if para_name.startswith('layer'): # 1, 2, 3, 4, 5 |
||||
stage_id, block_id = int(para_name.split('.')[0][5:]), int(para_name.split('.')[1]) |
||||
if stage_id == 1: |
||||
layer_id = 1 |
||||
elif stage_id == 2: |
||||
layer_id = 2 + block_id // blk2 # 2, 3 |
||||
elif stage_id == 3: |
||||
layer_id = 2 + N2 + block_id // blk3 # r50: 4, 5 r101: 4, 5, ..., 11 |
||||
else: # == 4 |
||||
layer_id = N # r50: 6 r101: 12 |
||||
elif para_name.startswith('fc.'): |
||||
layer_id = N+1 # r50: 7 r101: 13 |
||||
else: |
||||
layer_id = 0 |
||||
|
||||
return layer_id, N+1 - layer_id # r50: 0-7, 7-0 r101: 0-13, 13-0 |
||||
|
||||
|
||||
ResNet.get_layer_id_and_scale_exp = resnets_get_layer_id_and_scale_exp |
||||
ResNet.forward_features = forward_features |
||||
ResNet.forward = forward |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
import torch |
||||
from timm.models import create_model |
||||
r = create_model('resnet50') |
||||
with torch.no_grad(): |
||||
print(r(torch.rand(2, 3, 224, 224)).shape) |
||||
print(r(torch.rand(2, 3, 224, 224), pyramid=1)) |
||||
print(r(torch.rand(2, 3, 224, 224), pyramid=2)) |
||||
print(r(torch.rand(2, 3, 224, 224), pyramid=3)) |
||||
print(r(torch.rand(2, 3, 224, 224), pyramid=4)) |
@ -0,0 +1,167 @@ |
||||
# Copyright (c) ByteDance, Inc. and its affiliates. |
||||
# All rights reserved. |
||||
# |
||||
# This source code is licensed under the license found in the |
||||
# LICENSE file in the root directory of this source tree. |
||||
|
||||
import math |
||||
import random |
||||
|
||||
import numpy as np |
||||
import torch |
||||
from torch.utils.data.sampler import Sampler |
||||
|
||||
import dist |
||||
|
||||
|
||||
def worker_init_fn(worker_id): |
||||
# https://pytorch.org/docs/stable/notes/randomness.html#dataloader |
||||
worker_seed = torch.initial_seed() % 2 ** 32 |
||||
np.random.seed(worker_seed) |
||||
random.seed(worker_seed) |
||||
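|
||||
|
||||
# A minimal usage sketch (hypothetical helper; assumes `dataset` is any map-style Dataset): |
||||
# passing `worker_init_fn` above re-seeds numpy/random inside every DataLoader worker. |
||||
def _demo_worker_init_fn(dataset): |
||||
    from torch.utils.data import DataLoader |
||||
    return DataLoader(dataset, batch_size=32, num_workers=4, worker_init_fn=worker_init_fn) |
||||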
|
||||
|
||||
class RASampler(Sampler): |
||||
"""Sampler that restricts data loading to a subset of the dataset for distributed, |
||||
with repeated augmentation. |
||||
It ensures that different each augmented version of a sample will be visible to a |
||||
different process (GPU). |
||||
Heavily based on 'torch.utils.data.DistributedSampler'. |
||||
This is borrowed from the DeiT Repo: |
||||
https://github.com/facebookresearch/deit/blob/main/samplers.py |
||||
""" |
||||
|
||||
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3): |
||||
if num_replicas is None: |
||||
num_replicas = dist.get_world_size() |
||||
if rank is None: |
||||
rank = dist.get_rank() |
||||
self.dataset = dataset |
||||
self.num_replicas = num_replicas |
||||
self.rank = rank |
||||
self.epoch = 0 |
||||
self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas)) |
||||
self.total_size = self.num_samples * self.num_replicas |
||||
self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) |
||||
self.shuffle = shuffle |
||||
self.seed = seed |
||||
self.repetitions = repetitions |
||||
|
||||
def __iter__(self): |
||||
if self.shuffle: |
||||
# Deterministically shuffle based on epoch |
||||
g = torch.Generator() |
||||
g.manual_seed(self.seed + self.epoch) |
||||
indices = torch.randperm(len(self.dataset), generator=g).tolist() |
||||
else: |
||||
indices = list(range(len(self.dataset))) |
||||
|
||||
# Add extra samples to make it evenly divisible |
||||
indices = [ele for ele in indices for i in range(self.repetitions)] |
||||
indices += indices[: (self.total_size - len(indices))] |
||||
assert len(indices) == self.total_size |
||||
|
||||
# Subsample |
||||
indices = indices[self.rank : self.total_size : self.num_replicas] |
||||
assert len(indices) == self.num_samples |
||||
|
||||
return iter(indices[: self.num_selected_samples]) |
||||
|
||||
def __len__(self): |
||||
return self.num_selected_samples |
||||
|
||||
def set_epoch(self, epoch): |
||||
self.epoch = epoch |
||||
|
||||
|
||||
class InfiniteBatchSampler(Sampler): |
||||
def __init__(self, dataset_len, batch_size, seed=0, filling=False, shuffle=True, drop_last=False): |
||||
self.dataset_len = dataset_len |
||||
self.batch_size = batch_size |
||||
self.iters_per_ep = dataset_len // batch_size if drop_last else (dataset_len + batch_size - 1) // batch_size |
||||
self.max_p = self.iters_per_ep * batch_size |
||||
self.filling = filling |
||||
self.shuffle = shuffle |
||||
self.epoch = 0 |
||||
self.seed = seed |
||||
self.indices = self.gener_indices() |
||||
|
||||
def gener_indices(self): |
||||
if self.shuffle: |
||||
g = torch.Generator() |
||||
g.manual_seed(self.epoch + self.seed) |
||||
indices = torch.randperm(self.dataset_len, generator=g).numpy() |
||||
else: |
||||
indices = torch.arange(self.dataset_len).numpy() |
||||
|
||||
tails = self.batch_size - (self.dataset_len % self.batch_size) |
||||
if tails != self.batch_size and self.filling: |
||||
tails = indices[:tails] |
||||
np.random.shuffle(indices) |
||||
indices = np.concatenate((indices, tails)) |
||||
|
||||
# built-in list/tuple is faster than np.ndarray (when collating the data via a for-loop) |
||||
# noinspection PyTypeChecker |
||||
return tuple(indices.tolist()) |
||||
|
||||
def __iter__(self): |
||||
self.epoch = 0 |
||||
while True: |
||||
self.epoch += 1 |
||||
p, q = 0, 0 |
||||
while p < self.max_p: |
||||
q = p + self.batch_size |
||||
yield self.indices[p:q] |
||||
p = q |
||||
if self.shuffle: |
||||
self.indices = self.gener_indices() |
||||
|
||||
def __len__(self): |
||||
return self.iters_per_ep |
||||
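|
||||
|
||||
# A minimal usage sketch (hypothetical helper): an infinite batch sampler is passed via |
||||
# `batch_sampler=`, and the caller bounds consumption itself, e.g. with itertools.islice. |
||||
def _demo_infinite_batch_sampler(dataset): |
||||
    from itertools import islice |
||||
    from torch.utils.data import DataLoader |
||||
    sampler = InfiniteBatchSampler(dataset_len=len(dataset), batch_size=32, shuffle=True) |
||||
    loader = DataLoader(dataset, batch_sampler=sampler, num_workers=4) |
||||
    for batch in islice(loader, len(sampler)):  # one epoch's worth of iterations |
||||
        pass |
||||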
|
||||
|
||||
class DistInfiniteBatchSampler(InfiniteBatchSampler): |
||||
def __init__(self, world_size, rank, dataset_len, glb_batch_size, seed=0, repeated_aug=0, filling=False, shuffle=True): |
||||
# from torchvision.models import ResNet50_Weights |
||||
# RA sampler: https://github.com/pytorch/vision/blob/5521e9d01ee7033b9ee9d421c1ef6fb211ed3782/references/classification/sampler.py |
||||
|
||||
assert glb_batch_size % world_size == 0 |
||||
self.world_size, self.rank = world_size, rank |
||||
self.dataset_len = dataset_len |
||||
self.glb_batch_size = glb_batch_size |
||||
self.batch_size = glb_batch_size // world_size |
||||
|
||||
self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size |
||||
self.filling = filling |
||||
self.shuffle = shuffle |
||||
self.repeated_aug = repeated_aug |
||||
self.epoch = 0 |
||||
self.seed = seed |
||||
self.indices = self.gener_indices() |
||||
|
||||
def gener_indices(self): |
||||
global_max_p = self.iters_per_ep * self.glb_batch_size # global_max_p % world_size must be 0 since glb_batch_size % world_size == 0 |
||||
if self.shuffle: |
||||
g = torch.Generator() |
||||
g.manual_seed(self.epoch + self.seed) |
||||
global_indices = torch.randperm(self.dataset_len, generator=g) |
||||
if self.repeated_aug > 1: |
||||
global_indices = global_indices[:(self.dataset_len + self.repeated_aug - 1) // self.repeated_aug].repeat_interleave(self.repeated_aug, dim=0)[:global_max_p] |
||||
else: |
||||
global_indices = torch.arange(self.dataset_len) |
||||
filling = global_max_p - global_indices.shape[0] |
||||
if filling > 0 and self.filling: |
||||
global_indices = torch.cat((global_indices, global_indices[:filling])) |
||||
global_indices = tuple(global_indices.numpy().tolist()) |
||||
|
||||
seps = torch.linspace(0, len(global_indices), self.world_size + 1, dtype=torch.int) |
||||
local_indices = global_indices[seps[self.rank]:seps[self.rank + 1]] |
||||
self.max_p = len(local_indices) |
||||
return local_indices |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
W = 16 |
||||
for rk in range(W): |
||||
ind = DistInfiniteBatchSampler(W, rk, 5024, 5024).gener_indices() |
||||
print(rk, len(ind)) |
@ -0,0 +1,49 @@ |
||||
#!/usr/bin/env bash |
||||
|
||||
# an example of how to do pre-training: (note that `/path/to/imagenet` should contain directories named `train` and `val`) |
||||
# > cd /path/to/SparK |
||||
# > bash ./scripts/pt.sh experiment_name /path/to/imagenet --num_nodes=1 --ngpu_per_node=8 --node_rank=0 --master_address=128.0.0.0 --master_port=30000 --model=res50 --ep=400 |
||||
|
||||
####### template begins ####### |
||||
SCRIPTS_DIR=$(cd $(dirname $0); pwd) |
||||
cd "${SCRIPTS_DIR}" |
||||
cd ../ |
||||
SPARK_DIR=$(pwd) |
||||
echo "SPARK_DIR=${SPARK_DIR}" |
||||
|
||||
shopt -s expand_aliases |
||||
alias python=python3 |
||||
alias to_scripts_dir='cd "${SCRIPTS_DIR}"' |
||||
alias to_spark_dir='cd "${SPARK_DIR}"' |
||||
alias print='echo "$(date +"[%m-%d %H:%M:%S]") (exp.sh)=>"' |
||||
function mkd() { |
||||
mkdir -p "$1" >/dev/null 2>&1 |
||||
} |
||||
####### template ends ####### |
||||
|
||||
|
||||
EXP_NAME=$1 |
||||
DATA_PATH=$2 |
||||
|
||||
EXP_DIR="${SPARK_DIR}/output_${EXP_NAME}" |
||||
mkd "${EXP_DIR}" |
||||
|
||||
|
||||
print "===================== Args =====================" |
||||
print "EXP_NAME: ${EXP_NAME}" |
||||
print "DATA_PATH: ${DATA_PATH}" |
||||
print "EXP_DIR: ${EXP_DIR}" |
||||
print "[other_args sent to launch.py]: ${*:3}" |
||||
print "================================================" |
||||
print "" |
||||
|
||||
|
||||
print "============== Pretraining starts ==============" |
||||
to_spark_dir |
||||
python launch.py \ |
||||
--main_py_relpath main.py \ |
||||
--exp_name "${EXP_NAME}" \ |
||||
--data_path "${DATA_PATH}" \ |
||||
--exp_dir "${EXP_DIR}" \ |
||||
"${*:3}" |
||||
print "============== Pretraining ends ==============" |
@ -0,0 +1,313 @@ |
||||
# Copyright (c) ByteDance, Inc. and its affiliates. |
||||
# All rights reserved. |
||||
# |
||||
# This source code is licensed under the license found in the |
||||
# LICENSE file in the root directory of this source tree. |
||||
|
||||
import random |
||||
from pprint import pformat |
||||
from typing import List |
||||
|
||||
import numpy as np |
||||
import torch |
||||
import torch.nn as nn |
||||
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN |
||||
from timm.models.layers import trunc_normal_ |
||||
|
||||
import encoder |
||||
from decoder import LightDecoder |
||||
|
||||
|
||||
class SparK(nn.Module): |
||||
def __init__( |
||||
self, sparse_encoder: encoder.SparseEncoder, dense_decoder: LightDecoder, |
||||
mask_ratio=0.6, mask_ratio2=0.6, uniform=False, |
||||
using_pe=True, pix_norm=0, dense_loss=False, loss_l2=True, |
||||
en_de_norm='bn', en_de_lin=True, sbn=False, pyramid=1, # 1 for single-scale pre-training; 4 for full-scale pre-training |
||||
): |
||||
super().__init__() |
||||
input_size, downsample_raito = sparse_encoder.input_size, sparse_encoder.downsample_raito |
||||
self.downsample_raito = downsample_raito |
||||
fmap_size = input_size // downsample_raito |
||||
self.fmap_size = fmap_size |
||||
if mask_ratio != mask_ratio2 and not uniform: # with an extra active site |
||||
k = 1 / fmap_size**2 |
||||
mask_ratio = min(1, mask_ratio / (1-k)) |
||||
mask_ratio2 = min(1, mask_ratio2 / (1-k)) |
||||
self.mask_ratio = (mask_ratio, mask_ratio2) |
||||
self.ratios = torch.tensor([self.mask_ratio[0], self.mask_ratio[1], (self.mask_ratio[0] + self.mask_ratio[1]) / 2]) |
||||
self.uniform = uniform |
||||
self.len_keep = round(fmap_size * fmap_size * (1-mask_ratio)) |
||||
self.pix_norm = int(pix_norm) |
||||
|
||||
self.sparse_encoder = sparse_encoder |
||||
self.dense_decoder = dense_decoder |
||||
|
||||
self.sbn = sbn |
||||
self.pyramid = pyramid |
||||
en_de_norm = en_de_norm.lower() |
||||
self.en_de_norm_str = en_de_norm |
||||
|
||||
self.en_de_lin_bool = en_de_lin |
||||
self.en_de_norms = nn.ModuleList() |
||||
self.en_de_lins = nn.ModuleList() |
||||
self.using_pe = using_pe |
||||
self.pos_embeds = nn.ParameterList() |
||||
self.mask_tokens = nn.ParameterList() |
||||
|
||||
fea, d_fea, fmap = self.sparse_encoder.fea_dim, self.dense_decoder.fea_dim, fmap_size |
||||
for i in range(self.pyramid): |
||||
if en_de_norm == 'bn': |
||||
n = (encoder.SparseSyncBatchNorm2d if sbn else encoder.SparseBatchNorm2d)(fea) |
||||
elif en_de_norm == 'ln': |
||||
n = encoder.SparseConvNeXtLayerNorm(fea, data_format='channels_first', sparse=True) |
||||
else: |
||||
n = nn.Identity() |
||||
self.en_de_norms.append(n) |
||||
|
||||
kernel_size = 1 if i <= 0 else 3 |
||||
l = nn.Conv2d(fea, d_fea, kernel_size=kernel_size, stride=1, padding=kernel_size//2, bias=True) |
||||
print(f'[mid, py={self.pyramid}][edl {i}]: k={kernel_size}, #para = {sum(x.numel() for x in l.parameters())/1e6:.2f}') |
||||
if i == 0 and fea == d_fea: |
||||
l = nn.Identity() |
||||
self.en_de_lins.append(l) |
||||
|
||||
if self.using_pe: |
||||
p = torch.from_numpy(get_2d_sincos_pos_embed(fea, fmap)).float() |
||||
p = p.reshape(1, fmap, fmap, fea).permute(0, 3, 1, 2).contiguous() |
||||
p = nn.Parameter(p, requires_grad=False) |
||||
self.pos_embeds.append(p) |
||||
|
||||
p = nn.Parameter(torch.zeros(1, fea, 1, 1)) |
||||
trunc_normal_(p, mean=0, std=.02, a=-.02, b=.02) |
||||
self.mask_tokens.append(p) |
||||
fea //= 2 |
||||
d_fea //= 2 |
||||
fmap *= 2 |
||||
|
||||
print(f'[mid, py={self.pyramid}][mask_tokens]: {tuple(p.numel() for p in self.mask_tokens)}') |
||||
|
||||
self.loss_l2, self.dense_loss = loss_l2, dense_loss |
||||
|
||||
m = torch.tensor(IMAGENET_DEFAULT_MEAN).view(1, 3, 1, 1) |
||||
s = torch.tensor(IMAGENET_DEFAULT_STD).view(1, 3, 1, 1) |
||||
self.register_buffer('imn_m', m) |
||||
self.register_buffer('imn_s', s) |
||||
# self.register_buffer('norm_black', (torch.ones(1, 3, input_size, input_size) * 0.45 - m) / s) |
||||
self.register_buffer('norm_black', torch.zeros(1, 3, input_size, input_size)) |
||||
self.vis_active = self.vis_active_ex = self.vis_inp = self.vis_inp_mask = ... |
||||
|
||||
def mask(self, shape, device, generator=None): |
||||
B, C, H, W = shape |
||||
f = self.fmap_size |
||||
if self.mask_ratio[0] == self.mask_ratio[1]: |
||||
len_keep = self.len_keep |
||||
elif self.uniform: |
||||
r = random.uniform(self.mask_ratio[0], self.mask_ratio[1]) |
||||
len_keep = round(f * f * (1-r)) |
||||
else: |
||||
i1, i2, i3, i4 = np.linspace(0, B, 4, dtype=int).tolist() |
||||
l1, l2, l3 = i2-i1, i3-i2, i4-i3 |
||||
r1, r2, r3 = self.ratios[torch.randperm(3, generator=generator)].tolist() |
||||
r = torch.tensor([r1]*l1 + [r2]*l2 + [r3]*l3, device=device).view(-1, 1, 1) |
||||
active = torch.rand(B, f, f, device=device, generator=generator) >= r |
||||
|
||||
rr, cc = torch.randint(low=0, high=f, size=(2, B), generator=generator).unbind(0) |
||||
active[torch.arange(B), rr, cc] = True # an extra active site |
||||
return active.unsqueeze_(1) |
||||
|
||||
idx = torch.rand(B, f*f, generator=generator).argsort(dim=1) |
||||
idx = idx[:, :len_keep].to(device) # (B, len_keep) |
||||
return torch.zeros(B, f*f, dtype=torch.bool, device=device).scatter_(dim=1, index=idx, value=True).view(B, 1, f, f) |
||||
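|
||||
|
||||
# a quick sanity sketch (hypothetical helper, not in the original file): the boolean |
||||
# mask is (B, 1, fmap_size, fmap_size) and keeps roughly a (1 - mask_ratio) fraction |
||||
def _demo_mask_stats(self, B=4, device='cpu'): |
||||
    H = W = self.fmap_size * self.downsample_raito |
||||
    active = self.mask((B, 3, H, W), device) |
||||
    print(f'kept fraction: {active.float().mean().item():.3f}') |
||||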
|
||||
def forward(self, raw_inp: torch.Tensor, active=None): |
||||
inp_bchw = raw_inp |
||||
|
||||
# spatial mask |
||||
if active is None: |
||||
active: torch.BoolTensor = self.mask(inp_bchw.shape, inp_bchw.device) # (B, 1, f, f) |
||||
encoder._cur_active = active |
||||
active_ex = active.repeat_interleave(self.downsample_raito, 2).repeat_interleave(self.downsample_raito, 3) # (B, 1, H, W) |
||||
masked_bchw = inp_bchw * active_ex |
||||
|
||||
# get hierarchical encoded sparse features (a list containing four feature maps) |
||||
fea_bcffs: List[torch.Tensor] = self.sparse_encoder(masked_bchw, pyramid=self.pyramid) |
||||
fea_bcffs.reverse() # from the smallest feature map to the largest |
||||
|
||||
cur_active = active |
||||
to_dec = [] |
||||
for i, bcff in enumerate(fea_bcffs): # from the smallest feature map to the largest |
||||
if bcff is not None: |
||||
# fill in empty positions with [mask] embeddings |
||||
bcff = self.en_de_norms[i](bcff) |
||||
mask_tokens = self.mask_tokens[i].expand_as(bcff) |
||||
if self.using_pe: |
||||
mask_tokens = mask_tokens + self.pos_embeds[i].expand_as(bcff) |
||||
bcff = torch.where(cur_active.expand_as(bcff), bcff, mask_tokens) |
||||
bcff: torch.Tensor = self.en_de_lins[i](bcff) |
||||
to_dec.append(bcff) |
||||
cur_active = cur_active.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3) # dilate the mask map |
||||
|
||||
# decode and reconstruct |
||||
rec_bchw = self.dense_decoder(to_dec) |
||||
|
||||
# calc loss |
||||
mean, var, spatial_loss = self.spatial_loss(raw_inp, rec_bchw, active) |
||||
return active_ex, rec_bchw, spatial_loss |
||||
|
||||
def spatial_loss(self, inp, rec, active): # active: (B, 1, f, f) |
||||
mean = var = None |
||||
if self.pix_norm == 2: |
||||
mean, var = inp.mean(dim=(2, 3), keepdim=True), None |
||||
rec = rec + mean |
||||
inp = self.patchify(inp) |
||||
rec = self.patchify(rec) |
||||
# (B, L=fmap_size**2, N=downsample_raito**2 * C) |
||||
if self.pix_norm == 1: |
||||
mean = inp.mean(dim=-1, keepdim=True) |
||||
var = (inp.var(dim=-1, keepdim=True) + 1e-6)**.5 |
||||
inp = (inp - mean) / var |
||||
loss_spa = (rec-inp)**2 if self.loss_l2 else (rec-inp).abs() |
||||
|
||||
if self.dense_loss: |
||||
return mean, var, loss_spa.mean() # mean loss on all patches |
||||
else: |
||||
loss_spa = loss_spa.mean(dim=2, keepdim=False) # (B, L, C) => (B, L) |
||||
non_active = active.logical_not().int().view(active.shape[0], -1) # (B, 1, f, f) => (B, L) |
||||
return mean, var, loss_spa.mul_(non_active).sum() / (non_active.sum() + 1e-8) # mean loss on removed patches |
||||
|
||||
def patchify(self, bchw): |
||||
p = self.downsample_raito |
||||
h = w = self.fmap_size |
||||
B, C = bchw.shape[:2] |
||||
bchw = bchw.reshape(shape=(B, C, h, p, w, p)) |
||||
bchw = torch.einsum('bchpwq->bhwpqc', bchw) |
||||
bln = bchw.reshape(shape=(B, h*w, p**2 * C)) # (B, f*f, downsample_raito*downsample_raito*3) |
||||
return bln |
||||
|
||||
def unpatchify(self, bln): |
||||
p = self.downsample_raito |
||||
h = w = self.fmap_size |
||||
B, C = bln.shape[0], bln.shape[-1] // p**2 |
||||
bln = bln.reshape(shape=(B, h, w, p, p, C)) |
||||
bln = torch.einsum('bhwpqc->bchpwq', bln) |
||||
bchw = bln.reshape(shape=(B, C, h * p, w * p)) |
||||
return bchw |
||||
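|
||||
|
||||
# a round-trip sanity sketch (hypothetical helper): unpatchify exactly inverts patchify, |
||||
# so the reconstruction error below should be ~0 up to floating point |
||||
def _demo_patchify_roundtrip(self, B=2): |
||||
    s = self.fmap_size * self.downsample_raito |
||||
    x = torch.rand(B, 3, s, s) |
||||
    print((self.unpatchify(self.patchify(x)) - x).abs().max().item())  # ~0.0 |
||||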
|
||||
def denorm_for_vis(self, x, clamp): |
||||
x = x * self.imn_s |
||||
x += self.imn_m |
||||
if clamp: |
||||
x = torch.clamp(x, 0, 1) |
||||
return x |
||||
|
||||
def __repr__(self): |
||||
return ( |
||||
f'\n' |
||||
f'[SparK.config]: {pformat(self.get_config(), indent=2, width=250)}\n' |
||||
f'[SparK.structure]: {super(SparK, self).__repr__().replace(SparK.__name__, "")}\n' |
||||
f'[SparK.dec]: {self.dense_decoder.num_para()}' |
||||
) |
||||
|
||||
def get_config(self): |
||||
return { |
||||
# self |
||||
'mask_ratio': self.mask_ratio[0], 'mask_ratio2': self.mask_ratio[1], 'uniform': self.uniform, |
||||
'using_pe': self.using_pe, 'pix_norm': self.pix_norm, |
||||
'dense_loss': self.dense_loss, 'loss_l2': self.loss_l2, |
||||
'en_de_norm': self.en_de_norm_str, 'en_de_lin': self.en_de_lin_bool, 'sbn': self.sbn, 'pyramid': self.pyramid, |
||||
|
||||
# enc |
||||
'input_size': self.sparse_encoder.input_size, |
||||
# dec |
||||
'dec_fea_dim': self.dense_decoder.fea_dim, 'double': self.dense_decoder.double_bool, 'heavy': self.dense_decoder.heavy, |
||||
} |
||||
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False, with_config=False): |
||||
state = super(SparK, self).state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) |
||||
if with_config: |
||||
state['config'] = self.get_config() # todo: this seems to cause a DDP broadcast error?? |
||||
return state |
||||
|
||||
def load_state_dict(self, state_dict, strict=True): |
||||
config = state_dict.pop('config', None) |
||||
incompatible_keys = super(SparK, self).load_state_dict(state_dict, strict=strict) |
||||
if config is not None: |
||||
for k, v in self.get_config().items(): |
||||
if config.get(k, None) != v: |
||||
err = f'[SparseMIM.load_state_dict] config mismatch: this.{k}={v} (ckpt.{k}={config.get(k, None)})' |
||||
if strict: |
||||
raise AttributeError(err) |
||||
else: |
||||
print(err) |
||||
|
||||
return incompatible_keys |
||||
|
||||
|
||||
def _make_divisible(v, divisor=8, min_value=None): |
||||
""" |
||||
This function is taken from the original tf repo. |
||||
It ensures that all layers have a channel number that is divisible by 8 |
||||
It can be seen here: |
||||
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py |
||||
""" |
||||
if min_value is None: |
||||
min_value = divisor |
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) |
||||
# Make sure that round down does not go down by more than 10%. |
||||
if new_v < 0.9 * v: |
||||
new_v += divisor |
||||
return new_v |
||||
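|
||||
|
||||
# e.g. _make_divisible(37) == 40; _make_divisible(10) first rounds down to 8, which |
||||
# loses more than 10% of 10, so it is bumped back up to 16 |
||||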
|
||||
|
||||
def get_2d_sincos_pos_embed(embed_dim, grid_size): |
||||
""" |
||||
grid_size: int of the grid height and width |
||||
return: |
||||
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) |
||||
""" |
||||
grid_h = np.arange(grid_size, dtype=np.float32) |
||||
grid_w = np.arange(grid_size, dtype=np.float32) |
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first |
||||
grid = np.stack(grid, axis=0) |
||||
|
||||
grid = grid.reshape([2, 1, grid_size, grid_size]) |
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) |
||||
return pos_embed |
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): |
||||
assert embed_dim % 2 == 0 |
||||
|
||||
# use half of dimensions to encode grid_h |
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) |
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) |
||||
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) |
||||
return emb |
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): |
||||
""" |
||||
embed_dim: output dimension for each position |
||||
pos: a list of positions to be encoded: size (M,) |
||||
out: (M, D) |
||||
""" |
||||
assert embed_dim % 2 == 0 |
||||
omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed in NumPy 1.24; float64 matches the old behavior |
||||
omega /= embed_dim / 2. |
||||
omega = 1. / 10000**omega # (D/2,) |
||||
|
||||
pos = pos.reshape(-1) # (M,) |
||||
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product |
||||
|
||||
emb_sin = np.sin(out) # (M, D/2) |
||||
emb_cos = np.cos(out) # (M, D/2) |
||||
|
||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) |
||||
return emb |
||||
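|
||||
|
||||
# A minimal shape sketch: for a 7x7 feature map with 128-dim embeddings, |
||||
# get_2d_sincos_pos_embed(128, 7) returns an array of shape (49, 128). |
||||
def _demo_pos_embed_shape(): |
||||
    print(get_2d_sincos_pos_embed(128, 7).shape)  # (49, 128) |
||||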
|
||||
|
||||
if __name__ == '__main__': |
||||
# `test_mask` / `test_align` are not defined in this file; guard the calls so this demo entry stays runnable |
||||
if hasattr(SparK, 'test_mask'): SparK.test_mask() |
||||
if hasattr(SparK, 'test_align'): SparK.test_align() |
@ -0,0 +1,172 @@ |
||||
# Copyright (c) ByteDance, Inc. and its affiliates. |
||||
# All rights reserved. |
||||
# |
||||
# This source code is licensed under the license found in the |
||||
# LICENSE file in the root directory of this source tree. |
||||
|
||||
import os |
||||
from typing import Any, Callable, Optional, Tuple |
||||
|
||||
import PIL.Image as PImage |
||||
import torch |
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, create_transform |
||||
from timm.data.transforms_factory import transforms_imagenet_eval |
||||
from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS |
||||
from torchvision.transforms import transforms |
||||
|
||||
import dist |
||||
|
||||
try: |
||||
from torchvision.transforms import InterpolationMode |
||||
interpolation = InterpolationMode.BICUBIC |
||||
except ImportError: |
||||
import PIL |
||||
interpolation = PIL.Image.BICUBIC |
||||
|
||||
|
||||
def pil_loader(path): |
||||
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) |
||||
with open(path, 'rb') as f: img: PImage.Image = PImage.open(f).convert('RGB') |
||||
return img |
||||
|
||||
|
||||
class ImageNetDataset(DatasetFolder): |
||||
def __init__( |
||||
self, |
||||
root: str, |
||||
train: bool, |
||||
transform: Optional[Callable] = None, |
||||
target_transform: Optional[Callable] = None, |
||||
is_valid_file: Optional[Callable[[str], bool]] = None, |
||||
max_cls_id: int = 1000, |
||||
only=-1, |
||||
): |
||||
for postfix in (os.path.sep, 'train', 'val'): |
||||
if root.endswith(postfix): |
||||
root = root[:-len(postfix)] |
||||
|
||||
root = os.path.join(root, 'train' if train else 'val') |
||||
|
||||
super(ImageNetDataset, self).__init__( |
||||
root, |
||||
# loader=ImageLoader(train), |
||||
loader=pil_loader, |
||||
extensions=IMG_EXTENSIONS if is_valid_file is None else None, |
||||
transform=transform, target_transform=target_transform, is_valid_file=is_valid_file |
||||
) |
||||
|
||||
if only > 0: |
||||
g = torch.Generator() |
||||
g.manual_seed(0) |
||||
idx = torch.randperm(len(self.samples), generator=g).numpy().tolist() |
||||
|
||||
ws = dist.get_world_size() |
||||
res = (max_cls_id * only) % ws |
||||
more = 0 if res == 0 else (ws - res) |
||||
max_total = max_cls_id * only + more |
||||
if (max_total // ws) % 2 == 1: |
||||
more += ws |
||||
max_total += ws |
||||
|
||||
d = {c: [] for c in range(max_cls_id)} |
||||
max_len = {c: only for c in range(max_cls_id)} |
||||
for c in range(max_cls_id-more, max_cls_id): |
||||
max_len[c] += 1 |
||||
|
||||
total = 0 |
||||
for i in idx: |
||||
path, target = self.samples[i] |
||||
if len(d[target]) < max_len[target]: |
||||
d[target].append((path, target)) |
||||
total += 1 |
||||
if total == max_total: |
||||
break |
||||
sp = [] |
||||
for l in d.values(): sp.extend(l) |
||||
|
||||
print(f'[ds] more={more}, len(sp)={len(sp)}') |
||||
self.samples = tuple(sp) |
||||
self.targets = tuple([s[1] for s in self.samples]) |
||||
else: |
||||
self.samples = tuple(filter(lambda item: item[-1] < max_cls_id, self.samples)) |
||||
self.targets = tuple([s[1] for s in self.samples]) |
||||
|
||||
def __getitem__(self, index: int) -> Tuple[Any, Any]: |
||||
path, target = self.samples[index] |
||||
sample = self.loader(path) |
||||
if self.transform is not None: |
||||
sample = self.transform(sample) |
||||
if self.target_transform is not None: |
||||
target = self.target_transform(target) |
||||
|
||||
return sample, target |
||||
|
||||
|
||||
def build_imagenet(mode, data_path, data_set, img_size, eval_crop_pct=None, rrc=0.3, aa='rand-m7-mstd0.5', re_prob=0.0, colorj=0.4): |
||||
mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
||||
norm = transforms.Normalize(mean=mean, std=std) |
||||
|
||||
if img_size >= 384: |
||||
trans_val = transforms.Compose([ |
||||
transforms.Resize((img_size, img_size), interpolation=interpolation), |
||||
transforms.ToTensor(), |
||||
norm, |
||||
]) |
||||
else: |
||||
trans_val = transforms_imagenet_eval( |
||||
img_size=img_size, interpolation='bicubic', crop_pct=eval_crop_pct, |
||||
mean=mean, std=std |
||||
) |
||||
|
||||
mode = mode.lower() |
||||
if mode == 'pt': |
||||
trans_train = transforms.Compose([ |
||||
transforms.RandomResizedCrop(img_size, scale=(rrc, 1.0), interpolation=interpolation), |
||||
transforms.RandomHorizontalFlip(), |
||||
transforms.ToTensor(), |
||||
norm, |
||||
]) |
||||
elif mode == 'le': |
||||
trans_train = transforms.Compose([ |
||||
transforms.RandomResizedCrop(img_size, interpolation=interpolation), |
||||
transforms.RandomHorizontalFlip(), |
||||
transforms.ToTensor(), |
||||
norm, |
||||
]) |
||||
else: |
||||
trans_train = create_transform( |
||||
is_training=True, |
||||
input_size=img_size, |
||||
auto_augment=aa, |
||||
interpolation='bicubic', |
||||
re_prob=re_prob, |
||||
re_mode='pixel', |
||||
re_count=1, |
||||
color_jitter=colorj, |
||||
mean=mean, std=std, |
||||
) |
||||
|
||||
if data_path.endswith(os.path.sep): |
||||
data_path = data_path[:-len(os.path.sep)] |
||||
for postfix in ('train', 'val'): |
||||
if data_path.endswith(postfix): |
||||
data_path = data_path[:-len(postfix)] |
||||
|
||||
if data_set == 'imn': |
||||
dataset_train = ImageNetDataset(root=data_path, transform=trans_train, train=True) |
||||
dataset_val = ImageNetDataset(root=data_path, transform=trans_val, train=False) |
||||
num_classes = 1000 |
||||
else: |
||||
raise NotImplementedError |
||||
|
||||
print_transform(trans_train, '[train]') |
||||
print_transform(trans_val, '[val]') |
||||
|
||||
return dataset_train, dataset_val |
||||
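|
||||
|
||||
# A minimal usage sketch (assumes `/path/to/imagenet` holds `train`/`val` subfolders): |
||||
def _demo_build_imagenet(): |
||||
    train_set, val_set = build_imagenet('pt', '/path/to/imagenet', 'imn', img_size=224) |
||||
    print(len(train_set), len(val_set)) |
||||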
|
||||
|
||||
def print_transform(transform, s): |
||||
print(f'Transform {s} = ') |
||||
for t in transform.transforms: |
||||
print(t) |
||||
print('---------------------------\n') |
@ -0,0 +1,73 @@ |
||||
# Copyright (c) ByteDance, Inc. and its affiliates. |
||||
# All rights reserved. |
||||
# |
||||
# This source code is licensed under the license found in the |
||||
# LICENSE file in the root directory of this source tree. |
||||
|
||||
import math |
||||
from pprint import pformat |
||||
|
||||
|
||||
def lr_wd_annealing(optimizer, peak_lr, wd, wd_end, cur_it, wp_it, max_it): |
||||
wp_it = round(wp_it) |
||||
|
||||
if cur_it < wp_it: |
||||
cur_lr = 0.005 * peak_lr + 0.995 * peak_lr * cur_it / wp_it |
||||
else: |
||||
ratio = (cur_it - wp_it) / (max_it-1 - wp_it) |
||||
cur_lr = 0.001 * peak_lr + 0.999 * peak_lr * (0.5 + 0.5 * math.cos(math.pi * ratio)) |
||||
|
||||
ratio = cur_it / (max_it-1) |
||||
cur_wd = wd_end + (wd - wd_end) * (0.5 + 0.5 * math.cos(math.pi * ratio)) |
||||
|
||||
inf = 1e6 |
||||
min_lr, max_lr = inf, -1 |
||||
min_wd, max_wd = inf, -1 |
||||
for param_group in optimizer.param_groups: |
||||
param_group['lr'] = cur_lr * param_group.get('lr_scale', 1) # 'lr_scale' could be assigned |
||||
max_lr = max(max_lr, param_group['lr']) |
||||
min_lr = min(min_lr, param_group['lr']) |
||||
|
||||
param_group['weight_decay'] = cur_wd * param_group.get('weight_decay_scale', 1) |
||||
max_wd = max(max_wd, param_group['weight_decay']) |
||||
if param_group['weight_decay'] > 0: |
||||
min_wd = min(min_wd, param_group['weight_decay']) |
||||
|
||||
if min_lr == inf: min_lr = -1 |
||||
if min_wd == inf: min_wd = -1 |
||||
return min_lr, max_lr, min_wd, max_wd |
||||
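|
||||
|
||||
# A minimal sketch of the schedule (hypothetical values): linear lr warmup over the first |
||||
# `wp_it` iterations, then cosine annealing; weight decay is cosine-annealed from `wd` to `wd_end`. |
||||
def _demo_lr_wd_annealing(): |
||||
    import torch |
||||
    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.) |
||||
    for it in (0, 100, 1000, 9999): |
||||
        print(it, lr_wd_annealing(opt, peak_lr=2e-3, wd=0.04, wd_end=0.2, cur_it=it, wp_it=100, max_it=10000)) |
||||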
|
||||
|
||||
def get_param_groups(model, nowd_keys=(), lr_scale=0.0): |
||||
with_lr_scale = hasattr(model, 'get_layer_id_and_scale_exp') and 0 < lr_scale < 1 |
||||
print(f'[get_ft_param_groups][lr decay] with_lr_scale={with_lr_scale}, ft_lr_scale={lr_scale}') |
||||
para_groups, para_groups_dbg = {}, {} |
||||
|
||||
for name, para in model.named_parameters(): |
||||
if not para.requires_grad: |
||||
continue # frozen weights |
||||
if len(para.shape) == 1 or name.endswith('.bias') or any(k in name for k in nowd_keys): |
||||
wd_scale, group_name = 0., 'no_decay' |
||||
else: |
||||
wd_scale, group_name = 1., 'decay' |
||||
|
||||
if with_lr_scale: |
||||
layer_id, scale_exp = model.get_layer_id_and_scale_exp(name) |
||||
group_name = f'layer{layer_id}_' + group_name |
||||
cur_lr_scale = lr_scale ** scale_exp |
||||
dbg = f'[layer {layer_id}][sc = {lr_scale} ** {scale_exp}]' |
||||
else: |
||||
cur_lr_scale = 1 |
||||
dbg = f'[no scale]' |
||||
|
||||
if group_name not in para_groups: |
||||
para_groups[group_name] = {'params': [], 'weight_decay_scale': wd_scale, 'lr_scale': cur_lr_scale} |
||||
para_groups_dbg[group_name] = {'params': [], 'weight_decay_scale': wd_scale, 'lr_scale': dbg} |
||||
para_groups[group_name]['params'].append(para) |
||||
para_groups_dbg[group_name]['params'].append(name) |
||||
|
||||
for g in para_groups_dbg.values(): |
||||
g['params'] = pformat(', '.join(g['params']), width=200) |
||||
|
||||
print(f'[get_ft_param_groups] param groups = \n{pformat(para_groups_dbg, indent=2, width=250)}\n') |
||||
return list(para_groups.values()) |
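||||
|
||||
|
||||
# A minimal usage sketch (hypothetical toy model): the groups plug straight into any torch |
||||
# optimizer; `lr_scale` / `weight_decay_scale` are later consumed by lr_wd_annealing above. |
||||
def _demo_get_param_groups(): |
||||
    import torch |
||||
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8)) |
||||
    groups = get_param_groups(model, nowd_keys=('bias',), lr_scale=0.0) |
||||
    return torch.optim.AdamW(groups, lr=2e-3, weight_decay=0.04) |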
@ -0,0 +1,163 @@ |
||||
# Copyright (c) ByteDance, Inc. and its affiliates. |
||||
# All rights reserved. |
||||
# |
||||
# This source code is licensed under the license found in the |
||||
# LICENSE file in the root directory of this source tree. |
||||
|
||||
import json |
||||
import os |
||||
import re |
||||
import sys |
||||
|
||||
from tap import Tap |
||||
|
||||
import dist |
||||
|
||||
line_sep = f'\n{"=" * 80}\n' |
||||
|
||||
|
||||
class Args(Tap): |
||||
# environment |
||||
local_rank: int # useless |
||||
exp_name: str |
||||
data_path: str |
||||
exp_dir: str |
||||
log_txt_name: str = '/some/path/like/this/log.txt' |
||||
resume: str = '' |
||||
seed: int = 1 |
||||
device: str = 'cpu' |
||||
|
||||
# key MIM hp |
||||
mask: float = 0.6 |
||||
mask2: float = -1 |
||||
uni: bool = False |
||||
pe: bool = False |
||||
pn: int = 1 |
||||
py: int = 4 |
||||
# other MIM hp |
||||
den: bool = False |
||||
loss_l2: bool = True |
||||
en_de_norm: str = 'bn' |
||||
en_de_lin: bool = True |
||||
|
||||
# encoder |
||||
model: str = 'res50' |
||||
model_alias: str = 'res50' |
||||
input_size: int = 224 |
||||
sbn: bool = True |
||||
# decoder |
||||
dec_dim: int = 512 # [could be changed in `main.py`] |
||||
double: bool = True |
||||
hea: str = '0_1' |
||||
cmid: int = 0 |
||||
|
||||
# pre-training hyperparameters |
||||
glb_batch_size: int = 0 |
||||
batch_size: int = 0 # batch size per GPU |
||||
dp: float = 0.0 |
||||
base_lr: float = 2e-4 |
||||
lr: float = None |
||||
wd: float = 0.04 |
||||
wde: float = 0.2 |
||||
ep: int = 1600 |
||||
wp_ep: int = 40 |
||||
clip: float = 5. |
||||
opt: str = '' |
||||
ada: float = 0. |
||||
|
||||
# data hyperparameters |
||||
data_set: str = 'imn' |
||||
rrc: float = 0.67 |
||||
bs: int = 4096 |
||||
num_workers: int = 8 |
||||
|
||||
# would be added during runtime |
||||
cmd: str = '' |
||||
commit_id: str = '' |
||||
commit_msg: str = '' |
||||
last_loss = 1e9 # [would be changed in `main.py`] |
||||
cur_phase: str = '' # [would be changed in `main.py`] |
||||
cur_ep: str = '' # [would be changed in `main.py`] |
||||
remain_time: str = '' # [would be changed in `main.py`] |
||||
finish_time: str = '' # [would be changed in `main.py`] |
||||
|
||||
first_logging: bool = True |
||||
|
||||
@property |
||||
def is_convnext(self): |
||||
return 'convnext' in self.model or 'cnx' in self.model |
||||
|
||||
@property |
||||
def is_resnet(self): |
||||
return 'res' in self.model or 'res' in self.model_alias |
||||
|
||||
def __str__(self): |
||||
return re.sub(r"(\[LE-FT\]:\s*)('\s+')?", r'\1', super(Args, self).__str__()) |
||||
|
||||
def log_epoch(self): |
||||
if not dist.is_local_master(): |
||||
return |
||||
|
||||
if self.first_logging: |
||||
self.first_logging = False |
||||
with open(self.log_txt_name, 'w') as fp: |
||||
json.dump({ |
||||
'name': self.exp_name, 'cmd': self.cmd, 'commit_id': self.commit_id, |
||||
'model': self.model, 'opt': self.opt, |
||||
}, fp) |
||||
print('', end='\n', file=fp) |
||||
|
||||
with open(self.log_txt_name, 'a') as fp: |
||||
json.dump({ |
||||
'cur': self.cur_phase, 'cur_ep': self.cur_ep, |
||||
'last_L': self.last_loss, |
||||
'rema': self.remain_time, 'fini': self.finish_time, |
||||
}, fp) |
||||
|
||||
|
||||
def init_dist_and_get_args(): |
||||
from utils import misc |
||||
from models import model_alias_to_fullname, model_fullname_to_alias |
||||
|
||||
# initialize |
||||
args = Args(explicit_bool=True).parse_args() |
||||
misc.init_distributed_environ(exp_dir=args.exp_dir) |
||||
|
||||
# update args |
||||
args.cmd = ' '.join(sys.argv[1:]) |
||||
args.commit_id = os.popen('git rev-parse HEAD').read().strip() |
||||
args.commit_msg = os.popen('git log -1').read().strip().splitlines()[-1].strip() |
||||
|
||||
if args.model in model_alias_to_fullname.keys(): |
||||
args.model = model_alias_to_fullname[args.model] |
||||
args.model_alias = model_fullname_to_alias[args.model] |
||||
|
||||
args.device = dist.get_device() |
||||
args.batch_size = args.bs // dist.get_world_size() |
||||
args.glb_batch_size = args.batch_size * dist.get_world_size() |
||||
|
||||
if args.is_resnet: |
||||
args.opt = args.opt or 'lamb' |
||||
args.ada = args.ada or 0.95 |
||||
|
||||
if args.is_convnext: |
||||
args.opt = args.opt or 'lamb' |
||||
args.ada = args.ada or 0.999 |
||||
args.en_de_norm = 'ln' |
||||
|
||||
args.opt = args.opt.lower() |
||||
args.lr = args.base_lr * args.glb_batch_size / 256 |
||||
args.wde = args.wde or args.wd |
||||
|
||||
if args.mask2 < 0: |
||||
args.mask2 = args.mask |
||||
args.mask, args.mask2 = min(args.mask, args.mask2), max(args.mask, args.mask2) |
||||
|
||||
if args.py <= 0: |
||||
args.py = 1 |
||||
|
||||
args.hea = list(map(int, args.hea.split('_'))) |
||||
|
||||
args.log_txt_name = os.path.join(args.exp_dir, 'log.txt') |
||||
|
||||
return args |
@ -0,0 +1,273 @@ |
||||
# Copyright (c) ByteDance, Inc. and its affiliates. |
||||
# All rights reserved. |
||||
# |
||||
# This source code is licensed under the license found in the |
||||
# LICENSE file in the root directory of this source tree. |
||||
|
||||
import datetime |
||||
import functools |
||||
import os |
||||
import subprocess |
||||
import sys |
||||
import time |
||||
from collections import defaultdict, deque |
||||
from typing import Iterator |
||||
|
||||
import numpy as np |
||||
import pytz |
||||
import torch |
||||
import torch.distributed as tdist |
||||
|
||||
import dist |
||||
|
||||
os_system = functools.partial(subprocess.call, shell=True) |
||||
os_system_get_stdout = lambda cmd: subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8') |
||||
def os_system_get_stdout_stderr(cmd): |
||||
sp = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) |
||||
return sp.stdout.decode('utf-8'), sp.stderr.decode('utf-8') |
||||
|
||||
|
||||
def is_pow2n(x): |
||||
return x > 0 and (x & (x - 1) == 0) |
||||
|
||||
|
||||
def time_str(): |
||||
return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]') |
||||
|
||||
|
||||
def init_distributed_environ(exp_dir): |
||||
dist.initialize() |
||||
dist.barrier() |
||||
|
||||
import torch.backends.cudnn as cudnn |
||||
cudnn.benchmark = True |
||||
cudnn.deterministic = False |
||||
|
||||
_set_print_only_on_master_proc(is_master=dist.is_local_master()) |
||||
if dist.is_local_master() and len(exp_dir): |
||||
sys.stdout, sys.stderr = _SyncPrintToFile(exp_dir, stdout=True), _SyncPrintToFile(exp_dir, stdout=False) |
||||
|
||||
|
||||
def save_checkpoint(fname, args, epoch, performance_desc, model_without_ddp_state, optimizer_state): |
||||
checkpoint_path = os.path.join(args.exp_dir, fname) |
||||
if dist.is_local_master(): |
||||
to_save = { |
||||
'args': str(args), |
||||
'arch': args.model, |
||||
'epoch': epoch, |
||||
'performance_desc': performance_desc, |
||||
'module': model_without_ddp_state, |
||||
'optimizer': optimizer_state, |
||||
} |
||||
torch.save(to_save, checkpoint_path) |
||||
dist.barrier() |
||||
|
||||
|
||||
def load_checkpoint(fname, model_without_ddp, optimizer): |
||||
print(f'[try to resume from file `{fname}`]') |
||||
checkpoint = torch.load(fname, map_location='cpu') |
||||
|
||||
next_ep, performance_desc = checkpoint['epoch'] + 1, checkpoint['performance_desc'] |
||||
missing, unexpected = model_without_ddp.load_state_dict(checkpoint['module'], strict=False) |
||||
print(f'[load_checkpoint] missing_keys={missing}') |
||||
print(f'[load_checkpoint] unexpected_keys={unexpected}') |
||||
print(f'[load_checkpoint] next_ep={next_ep}, performance_desc={performance_desc}') |
||||
|
||||
if 'optimizer' in checkpoint: |
||||
optimizer.load_state_dict(checkpoint['optimizer']) |
||||
|
||||
return next_ep, performance_desc |
||||
|
||||
|
||||
class SmoothedValue(object): |
||||
"""Track a series of values and provide access to smoothed values over a |
||||
window or the global series average. |
||||
""" |
||||
|
||||
def __init__(self, window_size=20, fmt=None): |
||||
if fmt is None: |
||||
fmt = "{median:.4f} ({global_avg:.4f})" |
||||
self.deque = deque(maxlen=window_size) |
||||
self.total = 0.0 |
||||
self.count = 0 |
||||
self.fmt = fmt |
||||
|
||||
def update(self, value, n=1): |
||||
self.deque.append(value) |
||||
self.count += n |
||||
self.total += value * n |
||||
|
||||
def synchronize_between_processes(self): |
||||
""" |
||||
Warning: does not synchronize the deque! |
||||
""" |
||||
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') |
||||
tdist.barrier() |
||||
tdist.all_reduce(t) |
||||
t = t.tolist() |
||||
self.count = int(t[0]) |
||||
self.total = t[1] |
||||
|
||||
@property |
||||
def median(self): |
||||
d = torch.tensor(list(self.deque)) |
||||
return d.median().item() |
||||
|
||||
@property |
||||
def avg(self): |
||||
d = torch.tensor(list(self.deque), dtype=torch.float32) |
||||
return d.mean().item() |
||||
|
||||
@property |
||||
def global_avg(self): |
||||
return self.total / self.count |
||||
|
||||
@property |
||||
def max(self): |
||||
return max(self.deque) |
||||
|
||||
@property |
||||
def value(self): |
||||
return self.deque[-1] |
||||
|
||||
def time_preds(self, counts): |
||||
remain_secs = counts * self.median |
||||
remain_time = datetime.timedelta(seconds=round(remain_secs)) |
||||
finish_time = time.strftime("%m-%d %H:%M", time.localtime(time.time() + remain_secs)) |
||||
return remain_secs, str(remain_time), finish_time |
||||
|
||||
def __str__(self): |
||||
return self.fmt.format( |
||||
median=self.median, |
||||
avg=self.avg, |
||||
global_avg=self.global_avg, |
||||
max=self.max, |
||||
value=self.value) |
||||
|
||||
|
||||
class MetricLogger(object): |
||||
def __init__(self, delimiter="\t"): |
||||
self.meters = defaultdict(SmoothedValue) |
||||
self.delimiter = delimiter |
||||
|
||||
def update(self, **kwargs): |
||||
for k, v in kwargs.items(): |
||||
if v is None: |
||||
continue |
||||
if isinstance(v, torch.Tensor): |
||||
v = v.item() |
||||
assert isinstance(v, (float, int)) |
||||
self.meters[k].update(v) |
||||
|
||||
def __getattr__(self, attr): |
||||
if attr in self.meters: |
||||
return self.meters[attr] |
||||
if attr in self.__dict__: |
||||
return self.__dict__[attr] |
||||
raise AttributeError("'{}' object has no attribute '{}'".format( |
||||
type(self).__name__, attr)) |
||||
|
||||
def __str__(self): |
||||
loss_str = [] |
||||
for name, meter in self.meters.items(): |
||||
loss_str.append( |
||||
"{}: {}".format(name, str(meter)) |
||||
) |
||||
return self.delimiter.join(loss_str) |
||||
|
||||
def synchronize_between_processes(self): |
||||
for meter in self.meters.values(): |
||||
meter.synchronize_between_processes() |
||||
|
||||
def add_meter(self, name, meter): |
||||
self.meters[name] = meter |
||||
|
||||
def log_every(self, max_iters, itrt, print_freq, header=None): |
||||
print_iters = set(np.linspace(0, max_iters-1, print_freq, dtype=int).tolist()) |
||||
if not header: |
||||
header = '' |
||||
start_time = time.time() |
||||
end = time.time() |
||||
self.iter_time = SmoothedValue(fmt='{avg:.4f}') |
||||
self.data_time = SmoothedValue(fmt='{avg:.4f}') |
||||
space_fmt = ':' + str(len(str(max_iters))) + 'd' |
||||
log_msg = [ |
||||
header, |
||||
'[{0' + space_fmt + '}/{1}]', |
||||
'eta: {eta}', |
||||
'{meters}', |
||||
'time: {time}', |
||||
'data: {data}' |
||||
] |
||||
log_msg = self.delimiter.join(log_msg) |
||||
|
||||
if isinstance(itrt, Iterator) and not hasattr(itrt, 'preload') and not hasattr(itrt, 'set_epoch'): |
||||
for i in range(max_iters): |
||||
obj = next(itrt) |
||||
self.data_time.update(time.time() - end) |
||||
yield obj |
||||
self.iter_time.update(time.time() - end) |
||||
if i in print_iters: |
||||
eta_seconds = self.iter_time.global_avg * (max_iters - i) |
||||
eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) |
||||
print(log_msg.format( |
||||
i, max_iters, eta=eta_string, |
||||
meters=str(self), |
||||
time=str(self.iter_time), data=str(self.data_time))) |
||||
end = time.time() |
||||
else: |
||||
for i, obj in enumerate(itrt): |
||||
self.data_time.update(time.time() - end) |
||||
yield obj |
||||
self.iter_time.update(time.time() - end) |
||||
if i in print_iters: |
||||
eta_seconds = self.iter_time.global_avg * (max_iters - i) |
||||
eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) |
||||
print(log_msg.format( |
||||
i, max_iters, eta=eta_string, |
||||
meters=str(self), |
||||
time=str(self.iter_time), data=str(self.data_time))) |
||||
end = time.time() |
||||
|
||||
total_time = time.time() - start_time |
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time))) |
||||
print('{} Total time: {} ({:.3f} s / it)'.format( |
||||
header, total_time_str, total_time / max_iters)) |
||||
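|
||||
|
||||
# A minimal usage sketch (hypothetical helper): wrap any iterable with log_every and feed |
||||
# scalar metrics through update(); meters print as "median (global_avg)". |
||||
def _demo_metric_logger(): |
||||
    logger = MetricLogger(delimiter='  ') |
||||
    for step in logger.log_every(100, iter(range(100)), print_freq=5, header='[demo]'): |
||||
        logger.update(loss=1. / (step + 1)) |
||||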
|
||||
|
||||
def _set_print_only_on_master_proc(is_master): |
||||
import builtins as __builtin__ |
||||
|
||||
builtin_print = __builtin__.print |
||||
|
||||
def prt(msg, *args, **kwargs): |
||||
force = kwargs.pop('force', False) |
||||
clean = kwargs.pop('clean', False) |
||||
deeper = kwargs.pop('deeper', False) |
||||
if is_master or force: |
||||
if not clean: |
||||
f_back = sys._getframe().f_back |
||||
if deeper and f_back.f_back is not None: |
||||
f_back = f_back.f_back |
||||
file_desc = f'{f_back.f_code.co_filename:24s}'[-24:] |
||||
msg = f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})=> {msg}' |
||||
builtin_print(msg, *args, **kwargs) |
||||
|
||||
__builtin__.print = prt |
||||
|
||||
|
||||
class _SyncPrintToFile(object): |
||||
def __init__(self, exp_dir, stdout=True): |
||||
self.terminal = sys.stdout if stdout else sys.stderr |
||||
fname = os.path.join(exp_dir, 'stdout.txt' if stdout else 'stderr.txt') |
||||
self.log = open(fname, 'w') |
||||
self.log.flush() |
||||
|
||||
def write(self, message): |
||||
self.terminal.write(message) |
||||
self.log.write(message) |
||||
self.log.flush() |
||||
|
||||
def flush(self): |
||||
self.terminal.flush() |
||||
self.log.flush() |
@ -0,0 +1,160 @@ |
||||
# Copyright (c) ByteDance, Inc. and its affiliates. |
||||
# All rights reserved. |
||||
# |
||||
# This source code is licensed under the license found in the |
||||
# LICENSE file in the root directory of this source tree. |
||||
# |
||||
# This file is basically a copy to: https://github.com/rwightman/pytorch-image-models/blob/v0.5.4/timm/optim/lamb.py |
||||
|
||||
|
||||
""" PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb |
||||
This optimizer code was adapted from the following (starting with latest) |
||||
* https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py |
||||
* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py |
||||
* https://github.com/cybertronai/pytorch-lamb |
||||
Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is |
||||
similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX. |
||||
In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU. |
||||
Original copyrights for above sources are below. |
||||
Modifications Copyright 2021 Ross Wightman |
||||
""" |
||||
import math |
||||
|
||||
import torch |
||||
from torch.optim.optimizer import Optimizer |
||||
|
||||
|
||||
class TimmLAMB(Optimizer): |
||||
"""Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB |
||||
reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py |
||||
|
||||
LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. |
||||
|
||||
Arguments: |
||||
params (iterable): iterable of parameters to optimize or dicts defining parameter groups. |
||||
lr (float, optional): learning rate. (default: 1e-3) |
||||
betas (Tuple[float, float], optional): coefficients used for computing |
||||
running averages of gradient and its norm. (default: (0.9, 0.999)) |
||||
eps (float, optional): term added to the denominator to improve |
||||
numerical stability. (default: 1e-8) |
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0) |
||||
grad_averaging (bool, optional): whether apply (1-beta2) to grad when |
||||
calculating running averages of gradient. (default: True) |
||||
max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0) |
||||
trust_clip (bool): enable LAMBC trust ratio clipping (default: False) |
||||
always_adapt (boolean, optional): Apply adaptive learning rate to 0.0 |
||||
weight decay parameter (default: False) |
||||
|
||||
.. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: |
||||
https://arxiv.org/abs/1904.00962 |
||||
.. _On the Convergence of Adam and Beyond: |
||||
https://openreview.net/forum?id=ryQu7f-RZ |
||||
""" |
||||
|
||||
def __init__( |
||||
self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6, |
||||
weight_decay=0.01, grad_averaging=True, max_grad_norm=2.0, trust_clip=False, always_adapt=False): |
||||
defaults = dict( |
||||
lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay, |
||||
grad_averaging=grad_averaging, max_grad_norm=max_grad_norm, |
||||
trust_clip=trust_clip, always_adapt=always_adapt) |
||||
super().__init__(params, defaults) |
||||
print(f'[lamb1] max_grad_norm={max_grad_norm}') |
||||
self.global_grad_norm = 0 |
||||
|
||||
@torch.no_grad() |
||||
def step(self, closure=None): |
||||
"""Performs a single optimization step. |
||||
Arguments: |
||||
closure (callable, optional): A closure that reevaluates the model |
||||
and returns the loss. |
||||
""" |
||||
loss = None |
||||
if closure is not None: |
||||
with torch.enable_grad(): |
||||
loss = closure() |
||||
|
||||
device = self.param_groups[0]['params'][0].device |
||||
one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly |
||||
global_grad_norm = torch.zeros(1, device=device) |
||||
for group in self.param_groups: |
||||
for p in group['params']: |
||||
if p.grad is None: |
||||
continue |
||||
grad = p.grad |
||||
if grad.is_sparse: |
||||
raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.') |
||||
global_grad_norm.add_(grad.pow(2).sum()) |
||||
|
||||
global_grad_norm = torch.sqrt(global_grad_norm) |
||||
self.global_grad_norm = global_grad_norm.item() |
||||
max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device) |
||||
clip_global_grad_norm = 1 / torch.where( |
||||
global_grad_norm > max_grad_norm, |
||||
global_grad_norm / max_grad_norm, |
||||
one_tensor) |
||||
|
||||
for group in self.param_groups: |
||||
bias_correction = 1 if group['bias_correction'] else 0 |
||||
beta1, beta2 = group['betas'] |
||||
grad_averaging = 1 if group['grad_averaging'] else 0 |
||||
beta3 = 1 - beta1 if grad_averaging else 1.0 |
||||
|
||||
# assume same step across group now to simplify things |
||||
# per parameter step can be easily supported by making it a tensor, or by passing a list into the kernel |
||||
if 'step' in group: |
||||
group['step'] += 1 |
||||
else: |
||||
group['step'] = 1 |
||||
|
||||
if bias_correction: |
||||
bias_correction1 = 1 - beta1 ** group['step'] |
||||
bias_correction2 = 1 - beta2 ** group['step'] |
||||
else: |
||||
bias_correction1, bias_correction2 = 1.0, 1.0 |
||||
|
||||
for p in group['params']: |
||||
if p.grad is None: |
||||
continue |
||||
grad = p.grad.mul_(clip_global_grad_norm) |
||||
state = self.state[p] |
||||
|
||||
# State initialization |
||||
if len(state) == 0: |
||||
# Exponential moving average of gradient values |
||||
state['exp_avg'] = torch.zeros_like(p) |
||||
# Exponential moving average of squared gradient values |
||||
state['exp_avg_sq'] = torch.zeros_like(p) |
||||
|
||||
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] |
||||
|
||||
# Decay the first and second moment running average coefficient |
||||
exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t |
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # v_t |
||||
|
||||
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) |
||||
update = (exp_avg / bias_correction1).div_(denom) |
||||
|
||||
weight_decay = group['weight_decay'] |
||||
if weight_decay != 0: |
||||
update.add_(p, alpha=weight_decay) |
||||
|
||||
if weight_decay != 0 or group['always_adapt']: |
||||
# Layer-wise LR adaptation. By default, skip adaptation on parameters that are |
||||
# excluded from weight decay, unless always_adapt == True, then always enabled. |
||||
w_norm = p.norm(2.0) |
||||
g_norm = update.norm(2.0) |
||||
# FIXME nested where required since logical and/or not working in PT XLA |
||||
trust_ratio = torch.where( |
||||
w_norm > 0, |
||||
torch.where(g_norm > 0, w_norm / g_norm, one_tensor), |
||||
one_tensor, |
||||
) |
||||
if group['trust_clip']: |
||||
# LAMBC trust clipping, upper bound fixed at one |
||||
trust_ratio = torch.minimum(trust_ratio, one_tensor) |
||||
update.mul_(trust_ratio) |
||||
|
||||
p.add_(update, alpha=-group['lr']) |
||||
|
||||
return loss |
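||||
|
||||
|
||||
# A minimal usage sketch (hypothetical toy model): construct the optimizer and step; |
||||
# global grad-norm clipping at `max_grad_norm` happens inside step() itself. |
||||
def _demo_timm_lamb(): |
||||
    model = torch.nn.Linear(4, 4) |
||||
    opt = TimmLAMB(model.parameters(), lr=2e-3, weight_decay=0.01, max_grad_norm=2.0) |
||||
    model(torch.rand(2, 4)).sum().backward() |
||||
    opt.step() |
||||
    opt.zero_grad() |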