DETR Explained in One Article (Part 2)

Training Pipeline
1. The input image is passed through the CNN backbone to obtain a 32x-downsampled feature map.
2. The feature map is flattened into tokens, positional encodings are added, and the tokens are fed into the encoder (a sketch of a DETR-style sine positional encoding follows this list).
3. The encoder output, together with the object queries, is fed into the decoder to obtain decoded features.
4. The decoded features are passed through the FFN heads to obtain predictions.
5. A cost matrix is computed from the predictions, and the Hungarian algorithm matches predictions to the ground truth, yielding positive and negative samples.
6. Classification and regression losses are computed from the positive and negative samples.
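The code in this post does not cover how the positional encoding in step 2 is built (in mmdet it comes from a SinePositionalEncoding module). For reference, here is a simplified, self-contained sketch in the spirit of DETR's 2D sine encoding; the normalization and offset details of the real module are omitted, and num_feats=128 is chosen so the output has 256 channels to match the backbone feature:

import torch

def sine_position_embedding(mask, num_feats=128, temperature=10000):
    # mask: [bs, h, w]; nonzero entries mark padded pixels, so ~mask is the valid region.
    not_mask = ~mask
    y_embed = not_mask.cumsum(1, dtype=torch.float32)  # running row index inside the valid area
    x_embed = not_mask.cumsum(2, dtype=torch.float32)  # running column index inside the valid area
    dim_t = torch.arange(num_feats, dtype=torch.float32)
    dim_t = temperature ** (2 * (dim_t // 2) / num_feats)
    pos_x = x_embed[:, :, :, None] / dim_t
    pos_y = y_embed[:, :, :, None] / dim_t
    # interleave sin/cos over the channel dimension
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=4).flatten(3)
    pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=4).flatten(3)
    return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)  # [bs, 2 * num_feats, h, w]

mask = torch.zeros(2, 22, 38, dtype=torch.bool)  # no padding in this toy example
print(sine_position_embedding(mask).shape)       # torch.Size([2, 256, 22, 38])

With the running example of this article (a [2, 256, 22, 38] backbone feature), this yields a positional encoding of the same shape, which is later flattened together with the feature tokens.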

Code Implementation

Picking up from the previous post, where we walked through training step 1 at the code level, this post looks in detail at how DETR processes the tokens and positional encodings in the encoder. The entry point is the transformer's forward function, def forward(self, x, mask, query_embed, pos_embed). Here x is the 32x-downsampled feature from step 1 (shape [2, 256, 22, 38]); query_embed is the learnable embedding used by the decoder (the object query in the paper, shape [100, 256]); pos_embed is the positional encoding, with the same shape as x; and mask marks the valid image region (shape [2, 22, 38]).

First, the 2-D h x w layout of x, pos_embed and mask is flattened into a 1-D sequence of length h*w, i.e. x becomes a sequence of tokens. Then x (shape [836, 2, 256]) is passed into self.encoder as query, with query_pos=pos_embed (shape [836, 2, 256]) and query_key_padding_mask=mask (shape [2, 836]). The toy snippet below reproduces this reshaping.
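This uses nothing mmdet-specific, only plain tensor operations on the shapes of the running example:

import torch

bs, c, h, w = 2, 256, 22, 38
x = torch.randn(bs, c, h, w)             # 32x-downsampled backbone feature
pos_embed = torch.randn(bs, c, h, w)     # positional encoding, same shape as x
mask = torch.zeros(bs, h, w)             # key padding mask: nonzero entries are ignored pixels
query_embed = torch.randn(100, c)        # 100 learnable object queries

x = x.view(bs, c, -1).permute(2, 0, 1)                   # [2, 256, 22, 38] -> [836, 2, 256]
pos_embed = pos_embed.view(bs, c, -1).permute(2, 0, 1)   # [836, 2, 256]
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)  # [100, 256] -> [100, 2, 256]
mask = mask.view(bs, -1)                                 # [2, 22, 38] -> [2, 836]
print(x.shape, pos_embed.shape, query_embed.shape, mask.shape)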

class Transformer(BaseModule):
    """Implements the DETR transformer.

    Following the official DETR implementation, this module copy-paste
    from torch.nn.Transformer with modifications:

        * positional encodings are passed in MultiheadAttention
        * extra LN at the end of encoder is removed
        * decoder returns a stack of activations from all decoding layers

    See `paper: End-to-End Object Detection with Transformers
    <https://arxiv.org/pdf/2005.12872>`_ for details.

    Args:
        encoder (`mmcv.ConfigDict` | Dict): Config of
            TransformerEncoder. Defaults to None.
        decoder ((`mmcv.ConfigDict` | Dict)): Config of
            TransformerDecoder. Defaults to None
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Defaults to None.
    """

    def __init__(self, encoder=None, decoder=None, init_cfg=None):
        super(Transformer, self).__init__(init_cfg=init_cfg)
        self.encoder = build_transformer_layer_sequence(encoder)
        self.decoder = build_transformer_layer_sequence(decoder)
        self.embed_dims = self.encoder.embed_dims

    def init_weights(self):
        # follow the official DETR to init parameters
        for m in self.modules():
            if hasattr(m, 'weight') and m.weight.dim() > 1:
                xavier_init(m, distribution='uniform')
        self._is_init = True

    def forward(self, x, mask, query_embed, pos_embed):
        """Forward function for `Transformer`.

        Args:
            x (Tensor): Input query with shape [bs, c, h, w] where
                c = embed_dims.
            mask (Tensor): The key_padding_mask used for encoder and decoder,
                with shape [bs, h, w].
            query_embed (Tensor): The query embedding for decoder, with shape
                [num_query, c].
            pos_embed (Tensor): The positional encoding for encoder and
                decoder, with the same shape as `x`.

        Returns:
            tuple[Tensor]: results of decoder containing the following tensor.

                - out_dec: Output from decoder. If return_intermediate_dec
                      is True output has shape [num_dec_layers, bs,
                      num_query, embed_dims], else has shape [1, bs,
                      num_query, embed_dims].
                - memory: Output results from encoder, with shape
                      [bs, embed_dims, h, w].
        """
        bs, c, h, w = x.shape
        # use `view` instead of `flatten` for dynamically exporting to ONNX
        x = x.view(bs, c, -1).permute(2, 0, 1)  # [bs, c, h, w] -> [h*w, bs, c]
        pos_embed = pos_embed.view(bs, c, -1).permute(2, 0, 1)
        query_embed = query_embed.unsqueeze(1).repeat(
            1, bs, 1)  # [num_query, dim] -> [num_query, bs, dim]
        mask = mask.view(bs, -1)  # [bs, h, w] -> [bs, h*w]
        memory = self.encoder(
            query=x,
            key=None,
            value=None,
            query_pos=pos_embed,
            query_key_padding_mask=mask)
        target = torch.zeros_like(query_embed)
        # out_dec: [num_layers, num_query, bs, dim]
        out_dec = self.decoder(
            query=target,
            key=memory,
            value=memory,
            key_pos=pos_embed,
            query_pos=query_embed,
            key_padding_mask=mask)
        out_dec = out_dec.transpose(1, 2)
        memory = memory.permute(1, 2, 0).reshape(bs, c, h, w)
        return out_dec, memory

self.encoder is already registered when the Transformer is initialized; it is an instance of DetrTransformerEncoder(TransformerLayerSequence). As its forward shows, DetrTransformerEncoder does not implement the layer loop itself; it directly calls the forward of its parent class TransformerLayerSequence and only applies an optional post-norm afterwards.

class DetrTransformerEncoder(TransformerLayerSequence):
    """TransformerEncoder of DETR.

    Args:
        post_norm_cfg (dict): Config of last normalization layer. Default:
            `LN`. Only used when `self.pre_norm` is `True`
    """

    def __init__(self, *args, post_norm_cfg=dict(type='LN'), **kwargs):
        super(DetrTransformerEncoder, self).__init__(*args, **kwargs)
        if post_norm_cfg is not None:
            self.post_norm = build_norm_layer(
                post_norm_cfg, self.embed_dims)[1] if self.pre_norm else None
        else:
            assert not self.pre_norm, f'Use prenorm in ' \
                f'{self.__class__.__name__},' \
                f'Please specify post_norm_cfg'
            self.post_norm = None

    def forward(self, *args, **kwargs):
        """Forward function for `TransformerCoder`.

        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """
        x = super(DetrTransformerEncoder, self).forward(*args, **kwargs)
        if self.post_norm is not None:
            x = self.post_norm(x)
        return x

When TransformerLayerSequence is initialized, num_layers (set to 6 in the config) determines the number of encoder layers, i.e. the length of self.layers. transformerlayers is the dict passed in from the config, shown below; each element of self.layers is built from the BaseTransformerLayer class, with self.pre_norm=False and self.embed_dims=256. There are quite a few nested classes here, which makes things look complicated, so we will unpack them step by step.

{'type': 'BaseTransformerLayer', 'attn_cfgs': [{'type': 'MultiheadAttention', 'embed_dims': 256, 'num_heads': 8, 'dropout': 0.1}], 'feedforward_channels': 2048, 'ffn_dropout': 0.1, 'operation_order': ('self_attn', 'norm', 'ffn', 'norm')}
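For context, this layer config is the transformerlayers entry of the encoder section in the standard DETR config. Roughly, that section looks like the sketch below (reproduced from memory of the detr_r50 config, so field names may differ slightly across mmdet versions):

encoder = dict(
    type='DetrTransformerEncoder',
    num_layers=6,                       # 6 identical encoder layers
    transformerlayers=dict(
        type='BaseTransformerLayer',
        attn_cfgs=[
            dict(type='MultiheadAttention', embed_dims=256, num_heads=8, dropout=0.1)
        ],
        feedforward_channels=2048,      # hidden width of the FFN
        ffn_dropout=0.1,
        operation_order=('self_attn', 'norm', 'ffn', 'norm')))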

class TransformerLayerSequence(BaseModule):
    """Base class for TransformerEncoder and TransformerDecoder in vision
    transformer.

    As base-class of Encoder and Decoder in vision transformer.
    Support customization such as specifying different kind
    of `transformer_layer` in `transformer_coder`.

    Args:
        transformerlayer (list[obj:`mmcv.ConfigDict`] |
            obj:`mmcv.ConfigDict`): Config of transformerlayer
            in TransformerCoder. If it is obj:`mmcv.ConfigDict`,
            it would be repeated `num_layer` times to a
            list[`mmcv.ConfigDict`]. Default: None.
        num_layers (int): The number of `TransformerLayer`. Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
        super(TransformerLayerSequence, self).__init__(init_cfg)
        if isinstance(transformerlayers, dict):
            transformerlayers = [
                copy.deepcopy(transformerlayers) for _ in range(num_layers)
            ]
        else:
            assert isinstance(transformerlayers, list) and \
                   len(transformerlayers) == num_layers
        self.num_layers = num_layers
        self.layers = ModuleList()
        for i in range(num_layers):
            self.layers.append(build_transformer_layer(transformerlayers[i]))
        self.embed_dims = self.layers[0].embed_dims
        self.pre_norm = self.layers[0].pre_norm

    def forward(self,
                query,
                key,
                value,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `TransformerCoder`.

        Args:
            query (Tensor): Input query with shape
                `(num_queries, bs, embed_dims)`.
            key (Tensor): The key tensor with shape
                `(num_keys, bs, embed_dims)`.
            value (Tensor): The value tensor with shape
                `(num_keys, bs, embed_dims)`.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor], optional): Each element is 2D Tensor
                which is used in calculation of corresponding attention in
                operation_order. Default: None.
            query_key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_queries]. Only used in self-attention
                Default: None.
            key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_keys]. Default: None.

        Returns:
            Tensor:  results with shape [num_queries, bs, embed_dims].
        """
        for layer in self.layers:
            query = layer(
                query,
                key,
                value,
                query_pos=query_pos,
                key_pos=key_pos,
                attn_masks=attn_masks,
                query_key_padding_mask=query_key_padding_mask,
                key_padding_mask=key_padding_mask,
                **kwargs)
        return query

In TransformerLayerSequence.forward we see self.layers, the ModuleList built in __init__: it repeats the same layer 6 times, and each layer is an instance of BaseTransformerLayer.

In BaseTransformerLayer's initialization, self.batch_first=False, and operation_order stores the operator names in execution order, as shown below: self_attn refers to MultiheadAttention, norm to layer norm, and ffn to FFN. num_attn counts how many times self_attn (self-attention) and cross_attn (cross-attention) appear in operation_order. In the encoder, operation_order contains a single self_attn; the concrete difference between self_attn and cross_attn is discussed further below.

('self_attn', 'norm', 'ffn', 'norm')

attn_cfgs holds the parameters below; a batch_first=False entry is added to it, and attn_cfgs is then used to build the attention module, an instance of MultiheadAttention with a token dimension of 256, 8 heads and a dropout rate of 0.1. A quick shape sanity check follows the config.

[{'type': 'MultiheadAttention', 'embed_dims': 256, 'num_heads': 8, 'dropout': 0.1}]
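As a sanity check of this configuration: mmcv's MultiheadAttention is a wrapper around torch.nn.MultiheadAttention, so with embed_dims=256 and num_heads=8 each head operates on 256 / 8 = 32 channels, and the module preserves the token shape:

import torch
import torch.nn as nn

attn = nn.MultiheadAttention(embed_dim=256, num_heads=8, dropout=0.1)
tokens = torch.randn(836, 2, 256)   # [num_tokens, bs, embed_dims], batch_first=False
out, weights = attn(tokens, tokens, tokens)
print(out.shape)      # torch.Size([836, 2, 256]) -- same shape as the input tokens
print(weights.shape)  # torch.Size([2, 836, 836]) -- attention over all token pairs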

Next, num_ffns counts the FFNs in operation_order, and ffn_cfgs is built in the same way to hold the FFN parameters, with embed_dims kept consistent with the attention module, i.e. 256. Finally, self.norms points to the layer-norm modules.

class BaseTransformerLayer(BaseModule):
    """Base `TransformerLayer` for vision transformer.

    It can be built from `mmcv.ConfigDict` and support more flexible
    customization, for example, using any number of `FFN or LN ` and
    use different kinds of `attention` by specifying a list of `ConfigDict`
    named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
    when you specifying `norm` as the first element of `operation_order`.
    More details about the `prenorm`: `On Layer Normalization in the
    Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ .

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
            Configs for `self_attention` or `cross_attention` modules,
            The order of the configs in the list should be consistent with
            corresponding attentions in operation_order.
            If it is a dict, all of the attention modules in operation_order
            will be built with this config. Default: None.
        ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
            Configs for FFN, The order of the configs in the list should be
            consistent with corresponding ffn in operation_order.
            If it is a dict, all of the attention modules in operation_order
            will be built with this config.
        operation_order (tuple[str]): The execution order of operation
            in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
            Support `prenorm` when you specifying first element as `norm`.
            Default:None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): Key, Query and Value are shape
            of (batch, n, embed_dim)
            or (n, batch, embed_dim). Default to False.
    """

    def __init__(self,
                 attn_cfgs=None,
                 ffn_cfgs=dict(
                     type='FFN',
                     embed_dims=256,
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True),
                 ),
                 operation_order=None,
                 norm_cfg=dict(type='LN'),
                 init_cfg=None,
                 batch_first=False,
                 **kwargs):
        super(BaseTransformerLayer, self).__init__(init_cfg)

        self.batch_first = batch_first

        num_attn = operation_order.count('self_attn') + \
            operation_order.count('cross_attn')
        if isinstance(attn_cfgs, dict):
            attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
        else:
            assert num_attn == len(attn_cfgs), f'The length ' \
                f'of attn_cfg {num_attn} is ' \
                f'not consistent with the number of attention' \
                f'in operation_order {operation_order}.'

        self.num_attn = num_attn
        self.operation_order = operation_order
        self.norm_cfg = norm_cfg
        self.pre_norm = operation_order[0] == 'norm'
        self.attentions = ModuleList()

        index = 0
        for operation_name in operation_order:
            if operation_name in ['self_attn', 'cross_attn']:
                if 'batch_first' in attn_cfgs[index]:
                    assert self.batch_first == attn_cfgs[index]['batch_first']
                else:
                    attn_cfgs[index]['batch_first'] = self.batch_first
                attention = build_attention(attn_cfgs[index])
                # Some custom attentions used as `self_attn`
                # or `cross_attn` can have different behavior.
                attention.operation_name = operation_name
                self.attentions.append(attention)
                index += 1

        self.embed_dims = self.attentions[0].embed_dims

        self.ffns = ModuleList()
        num_ffns = operation_order.count('ffn')
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = ConfigDict(ffn_cfgs)
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
        assert len(ffn_cfgs) == num_ffns
        for ffn_index in range(num_ffns):
            if 'embed_dims' not in ffn_cfgs[ffn_index]:
                ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
            else:
                assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
            self.ffns.append(
                build_feedforward_network(ffn_cfgs[ffn_index],
                                          dict(type='FFN')))

        self.norms = ModuleList()
        num_norms = operation_order.count('norm')
        for _ in range(num_norms):
            self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])

With BaseTransformerLayer's initialization covered, execution continues into its forward. From the analysis of self.encoder's forward above, we already know the arguments that arrive here: x as query, query_pos=pos_embed, query_key_padding_mask=mask. In the loop over self.operation_order we first enter the self_attn branch, i.e. the MultiheadAttention module. Note that temp_key = temp_value = query, so q, k and v are all x.

    def forward(self,
                query,
                key=None,
                value=None,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `TransformerDecoderLayer`.

        **kwargs contains some specific arguments of attentions.

        Args:
            query (Tensor): The input query with shape
                [num_queries, bs, embed_dims] if
                self.batch_first is False, else
                [bs, num_queries embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims] .
            value (Tensor): The value tensor with same shape as `key`.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor] | None): 2D Tensor used in
                calculation of corresponding attention. The length of
                it should equal to the number of `attention` in
                `operation_order`. Default: None.
            query_key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_queries]. Only used in `self_attn` layer.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_keys]. Default: None.

        Returns:
            Tensor: forwarded results with shape [num_queries, bs, embed_dims].
        """
        norm_index = 0
        attn_index = 0
        ffn_index = 0
        identity = query
        if attn_masks is None:
            attn_masks = [None for _ in range(self.num_attn)]
        elif isinstance(attn_masks, torch.Tensor):
            attn_masks = [
                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
            ]
            warnings.warn(f'Use same attn_mask in all attentions in '
                          f'{self.__class__.__name__} ')
        else:
            assert len(attn_masks) == self.num_attn, f'The length of ' \
                f'attn_masks {len(attn_masks)} must be equal ' \
                f'to the number of attention in ' \
                f'operation_order {self.num_attn}'

        for layer in self.operation_order:
            if layer == 'self_attn':
                temp_key = temp_value = query
                query = self.attentions[attn_index](
                    query,
                    temp_key,
                    temp_value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=query_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=query_key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'norm':
                query = self.norms[norm_index](query)
                norm_index += 1

            elif layer == 'cross_attn':
                query = self.attentions[attn_index](
                    query,
                    key,
                    value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'ffn':
                query = self.ffns[ffn_index](
                    query, identity if self.pre_norm else None)
                ffn_index += 1

        return query

Next comes MultiheadAttention's initialization. Its parameters were already covered above, so we skip the details here.

class MultiheadAttention(BaseModule):
    """A wrapper for ``torch.nn.MultiheadAttention``.

    This module implements MultiheadAttention with identity connection,
    and positional encoding  is also passed as input.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Default: 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): When it is True,  Key, Query and Value are shape of
            (batch, n, embed_dim), otherwise (n, batch, embed_dim).
            Default to False.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=dict(type='Dropout', drop_prob=0.),
                 init_cfg=None,
                 batch_first=False,
                 **kwargs):
        super(MultiheadAttention, self).__init__(init_cfg)
        if 'dropout' in kwargs:
            warnings.warn(
                'The arguments `dropout` in MultiheadAttention '
                'has been deprecated, now you can separately '
                'set `attn_drop`(float), proj_drop(float), '
                'and `dropout_layer`(dict) ', DeprecationWarning)
            attn_drop = kwargs['dropout']
            dropout_layer['drop_prob'] = kwargs.pop('dropout')

        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.batch_first = batch_first

        self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
                                          **kwargs)

        self.proj_drop = nn.Dropout(proj_drop)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else nn.Identity()

Looking at MultiheadAttention's forward: q, k and v are all x (shape [836, 2, 256]) and query_pos=key_pos=pos_embed, so the positional encoding is added to both q and k, combining the image's spatial information with the tokens. self.attn is standard multi-head attention, whose details were covered in the ViT post and are not repeated here; finally, the return value adds the residual connection and dropout.

In short, this is self-attention (self_attn): q = k = v = x, with the positional encoding added to q and k. A condensed sketch of this data flow follows.
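Below is a condensed, runnable sketch of this data flow using plain torch modules. It is not the mmcv implementation (the padding mask, batch_first handling and the exact dropout layout of MultiheadAttention.forward, shown further down, are omitted), but it captures how the positional encoding and residual are used:

import torch
import torch.nn as nn

attn = nn.MultiheadAttention(embed_dim=256, num_heads=8, dropout=0.1)
drop = nn.Dropout(0.1)

x = torch.randn(836, 2, 256)          # tokens from the flattened feature map
pos_embed = torch.randn(836, 2, 256)  # positional encoding

q = x + pos_embed                     # position information is added to q and k only
k = x + pos_embed
v = x                                 # values stay position-free
out = attn(q, k, v)[0]
x = x + drop(out)                     # residual connection back onto the tokens
print(x.shape)                        # torch.Size([836, 2, 256])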

After leaving MultiheadAttention's forward, execution returns to BaseTransformerLayer's forward and continues through the FFN and layer norm. This whole process is repeated 6 times, and with that self.encoder's work is done.

    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_pos=None,
                attn_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `MultiheadAttention`.

        **kwargs allow passing a more general data flow when combining
        with other operations in `transformerlayer`.

        Args:
            query (Tensor): The input query with shape [num_queries, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_queries embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims] .
                If None, the ``query`` will be used. Defaults to None.
            value (Tensor): The value tensor with same shape as `key`.
                Same in `nn.MultiheadAttention.forward`. Defaults to None.
                If None, the `key` will be used.
            identity (Tensor): This tensor, with the same shape as x,
                will be used for the identity link.
                If None, `x` will be used. Defaults to None.
            query_pos (Tensor): The positional encoding for query, with
                the same shape as `x`. If not None, it will
                be added to `x` before forward function. Defaults to None.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. Defaults to None. If not None, it will
                be added to `key` before forward function. If None, and
                `query_pos` has the same shape as `key`, then `query_pos`
                will be used for `key_pos`. Defaults to None.
            attn_mask (Tensor): ByteTensor mask with shape [num_queries,
                num_keys]. Same in `nn.MultiheadAttention.forward`.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
                Defaults to None.

        Returns:
            Tensor: forwarded results with shape
            [num_queries, bs, embed_dims]
            if self.batch_first is False, else
            [bs, num_queries embed_dims].
        """
        if key is None:
            key = query
        if value is None:
            value = key
        if identity is None:
            identity = query
        if key_pos is None:
            if query_pos is not None:
                # use query_pos if key_pos is not available
                if query_pos.shape == key.shape:
                    key_pos = query_pos
                else:
                    warnings.warn(f'position encoding of key is'
                                  f'missing in {self.__class__.__name__}.')
        if query_pos is not None:
            query = query + query_pos
        if key_pos is not None:
            key = key + key_pos

        # Because the dataflow('key', 'query', 'value') of
        # ``torch.nn.MultiheadAttention`` is (num_query, batch,
        # embed_dims), We should adjust the shape of dataflow from
        # batch_first (batch, num_query, embed_dims) to num_query_first
        # (num_query ,batch, embed_dims), and recover ``attn_output``
        # from num_query_first to batch_first.
        if self.batch_first:
            query = query.transpose(0, 1)
            key = key.transpose(0, 1)
            value = value.transpose(0, 1)

        out = self.attn(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask)[0]

        if self.batch_first:
            out = out.transpose(0, 1)

        return identity + self.dropout_layer(self.proj_drop(out))

Execution then returns to the Transformer's forward. The memory produced by self.encoder has shape [836, 2, 256] and is reused as k and v for self.decoder. In self.decoder, q=target, where target is an all-zero tensor with the same shape as query_embed, i.e. [100, 2, 256]; key_pos=pos_embed, query_pos=query_embed and key_padding_mask=mask. These arguments are passed down into BaseTransformerLayer's forward.

        memory = self.encoder(
            query=x,
            key=None,
            value=None,
            query_pos=pos_embed,
            query_key_padding_mask=mask)
        target = torch.zeros_like(query_embed)
        # out_dec: [num_layers, num_query, bs, dim]
        out_dec = self.decoder(
            query=target,
            key=memory,
            value=memory,
            key_pos=pos_embed,
            query_pos=query_embed,
            key_padding_mask=mask)
        out_dec = out_dec.transpose(1, 2)
        memory = memory.permute(1, 2, 0).reshape(bs, c, h, w)
        return out_dec, memory

self.decoder's self.operation_order is shown below; it contains both a self_attn (self-attention) and a cross_attn (cross-attention) module, and num_layers=6. A sketch of a typical decoder config follows it.

('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm')
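For reference, the decoder section of the standard DETR config looks roughly like the sketch below. In mmdet the decoder layer type is DetrTransformerDecoderLayer, a thin subclass of BaseTransformerLayer; this is reproduced from memory of the detr_r50 config, so details may vary across versions:

decoder = dict(
    type='DetrTransformerDecoder',
    return_intermediate=True,           # keep the output of every decoder layer
    num_layers=6,
    transformerlayers=dict(
        type='DetrTransformerDecoderLayer',
        attn_cfgs=dict(type='MultiheadAttention', embed_dims=256, num_heads=8, dropout=0.1),
        feedforward_channels=2048,
        ffn_dropout=0.1,
        operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm')))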

The self_attn module was introduced above. Note that in self.decoder's first self_attn, q=target (an all-zero tensor of shape [100, 2, 256]) and temp_key = temp_value = query, so when entering self.attentions, q, k and v are all zeros (before the query positional embedding is added).

Next comes the cross_attn branch: query is the output of self_attn, and key=value=memory, i.e. k and v are the features produced by self.encoder (shape [836, 2, 256]). The difference between cross-attention and self-attention is that in cross-attention the query is not equal to the key and value; q, k and v come from different sources. Note that query_pos here is the object query (a learnable embedding of shape [100, 256], repeated over the batch to [100, 2, 256]); inside the attention, the query is added to the object query to form a new query. One way to think about it: adding the object query in multi-head attention is like adding 100 learnable anchors. These anchors compute similarities with k, which are normalized by softmax and used as coefficients over the values to form new features. Through training, each object query learns to summarize the locations it is interested in, which is what enables DETR to perform object detection. A shape-level sketch is given after the snippets below.

        if query_pos is not None:
            query = query + query_pos
        if key_pos is not None:
            key = key + key_pos

            if layer == 'self_attn':
                temp_key = temp_value = query
                query = self.attentions[attn_index](
                    query,
                    temp_key,
                    temp_value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=query_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=query_key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'norm':
                query = self.norms[norm_index](query)
                norm_index += 1

            elif layer == 'cross_attn':
                query = self.attentions[attn_index](
                    query,
                    key,
                    value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query
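The shape-level sketch below illustrates the decoder cross-attention described above with plain torch modules (toy tensors, no padding mask; the real code also passes key_padding_mask=mask):

import torch
import torch.nn as nn

cross_attn = nn.MultiheadAttention(embed_dim=256, num_heads=8, dropout=0.1)

tgt = torch.zeros(100, 2, 256)          # decoder queries (all zeros at the first layer)
query_embed = torch.randn(100, 2, 256)  # object query, acting like 100 learnable "anchors"
memory = torch.randn(836, 2, 256)       # encoder output
pos_embed = torch.randn(836, 2, 256)    # positional encoding of the feature map

q = tgt + query_embed                   # query carries the object-query embedding
k = memory + pos_embed                  # key carries the image position information
v = memory                              # values are the raw encoder features
out, attn_weights = cross_attn(q, k, v)
print(out.shape)                        # torch.Size([100, 2, 256])
print(attn_weights.shape)               # torch.Size([2, 100, 836]): each query attends over all tokens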

Since DETR has no anchor priors, it is hard to optimize. To provide better gradients, the features output by all 6 decoder layers are kept, and each of them is assigned an auxiliary decoding loss. This concludes the walkthrough of self.decoder.

class DetrTransformerDecoder(TransformerLayerSequence):
    """Implements the decoder in DETR transformer.

    Args:
        return_intermediate (bool): Whether to return intermediate outputs.
        post_norm_cfg (dict): Config of last normalization layer. Default:
            `LN`.
    """

    def __init__(self,
                 *args,
                 post_norm_cfg=dict(type='LN'),
                 return_intermediate=False,
                 **kwargs):
        super(DetrTransformerDecoder, self).__init__(*args, **kwargs)
        self.return_intermediate = return_intermediate
        if post_norm_cfg is not None:
            self.post_norm = build_norm_layer(post_norm_cfg,
                                              self.embed_dims)[1]
        else:
            self.post_norm = None

    def forward(self, query, *args, **kwargs):
        """Forward function for `TransformerDecoder`.

        Args:
            query (Tensor): Input query with shape
                `(num_query, bs, embed_dims)`.

        Returns:
            Tensor: Results with shape [1, num_query, bs, embed_dims] when
                return_intermediate is `False`, otherwise it has shape
                [num_layers, num_query, bs, embed_dims].
        """
        if not self.return_intermediate:
            x = super().forward(query, *args, **kwargs)
            if self.post_norm:
                x = self.post_norm(x)[None]
            return x

        intermediate = []
        for layer in self.layers:
            query = layer(query, *args, **kwargs)
            if self.return_intermediate:
                if self.post_norm is not None:
                    intermediate.append(self.post_norm(query))
                else:
                    intermediate.append(query)
        return torch.stack(intermediate)
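To make the effect of return_intermediate concrete, the toy snippet below mimics the stacked decoder output and how each layer's features would be routed to the shared prediction heads for the auxiliary losses (the layer indexing is illustrative only):

import torch

num_layers, num_query, bs, embed_dims = 6, 100, 2, 256
out_dec = torch.randn(num_layers, num_query, bs, embed_dims)  # torch.stack(intermediate)
out_dec = out_dec.transpose(1, 2)                             # [6, 2, 100, 256], as in Transformer.forward
for layer_idx in range(num_layers):
    layer_feat = out_dec[layer_idx]   # [2, 100, 256], fed to the shared cls/reg FFN heads
    # layers 0-4 contribute auxiliary losses; layer 5 gives the final predictions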

This wraps up the DETR model architecture. In the next part we will go through label assignment and the corresponding losses, and see how DETR elegantly solves end-to-end object detection.
