自制深度学习推理框架之表达式层的设计与实现

文章目录

    • 一、表达式Expression
    • 二、词法解析
      • 2.1 词法定义
      • 2.2 词法解析
    • 三、语法解析
      • 3.1 语法树的定义
      • 3.2 语法树构建
      • 3.3 语法树的转换(逆波兰式)
    • 四、表达式层
      • 4.1 ExpressionLayer和ExpressionParser类
      • 4.2 表达式层的注册
      • 4.3 表达式层的输入处理
      • 4.4 表达式层的计算过程
    • 五、计算图执行函数
      • 5.1 RuntimeGraph::ProbeNextLayer()
      • 5.2 RuntimeGraph::Forward()

一、表达式Expression

在 PNNX 中,Expression 类用于表示和处理计算图中的算子或节点的表达式。这些表达式通常涉及张量之间的运算、函数调用、以及其他数学或逻辑操作。在前面也说过,PNNX 会保留 PyTorch 所定义的表达式。

import torchdef foo(x, y):return torch.sqrt((2 * x + y) / 12)
imgmath.ptmath.pnnx

PyTorch中定义表达式在转换为PNNX之后,会保留表达式的整体结构,而不会被拆分成多个小的加减乘除算子。例如表达式sqrt(div(add(mul(@0,2),@1,1),12))不会被拆分为两个mul算子、一个add算子、一个div和sqrt算子,而是会生成一个表达式算子Expression

Expression层的主要功能:

  • 表达算子操作:Expression层可以表示各种算子操作,如加法、乘法、卷积等。这些算子通常对应于神经网络中的基本操作。

  • 表达式解析:Expressio层能够解析字符串形式的表达式,将其转换为内部的表达式树或图结构。这种结构可以进一步用于计算、求导或其他符号操作。

在PNNX的**表达式层(Expression Layer)**中,提供了一种计算表达式,该表达式能够在一定程度上折叠计算过程并消除中间变量。例如,在残差结构中的add操作在PNNX中就是一个表达式层。

下面是PNNX中对上述过程的计算表达式表示,其中的@0@1代表之前提到的计算数RuntimeOperand,用于表示计算表达式中的输入节点。

mul(@2, add(@0, @1)); 

尽管这个抽象表达式看起来比较简单,但实际上可能存在更为复杂的情况,例如以下的例子。因此,在这种情况下,需要一个强大而可靠的表达式解析和语法树构建功能

add(add(mul(@0, @1), mul(@2, add(add(add(@0, @2), @3), @4))), @5);

二、词法解析

2.1 词法定义

词法解析的目的是将expr:**add(@0, mul(@1, @2))**拆分为多个Token,拆分后的Token依次为:

  1. Identifier: add
  2. Left bracket: (
  3. Input number: @0
  4. Comma: ,
  5. Identifier: mul
  6. Left bracket: (
  7. Input number: @1
  8. Comma: ,
  9. Input number: @2
  10. Right bracket: )

Token类型定义如下:

enum class TokenType {TokenUnknown = -9,TokenInputNumber = -8,    // 数字TokenComma = -7,          // 逗号TokenAdd = -6,            // 加法TokenMul = -5,            // 乘法TokenLeftBracket = -4,    // 左括号TokenRightBracket = -3,   // 右括号
};

Token的定义如下,包括以下变量:

  1. Token类型,包括add(加法),mul(乘法),bracket(左右括号)等;
  2. Token在原句子中的开始和结束位置,即start_posend_pos

对于表达式add(@0, mul(@1, @2)),可以将它切分为多个Token,其中Token(add)的start_pos为0,end_pos为3。Token(left bracket)的start_pos为3,end_pos为4。Token(@0)的start_pos为4,end_pos为5,以此类推。

// 词语Token
struct Token {TokenType token_type = TokenType::TokenUnknown;int32_t start_pos = 0; // 词语开始的位置int32_t end_pos = 0;   // 词语结束的位置Token(TokenType token_type, int32_t start_pos, int32_t end_pos): token_type(token_type), start_pos(start_pos), end_pos(end_pos) {}
};

最后,在词法解析结束后,需要将这些 Token(词语)按照它们的出现顺序和层级关系组成一棵语法树

// 语法树的节点
struct TokenNode {int32_t num_index = -1;std::shared_ptr<TokenNode> left = nullptr;   // 语法树的左节点std::shared_ptr<TokenNode> right = nullptr;  // 语法树的右节点TokenNode(int32_t num_index, std::shared_ptr<TokenNode> left,std::shared_ptr<TokenNode> right);TokenNode() = default;
};

2.2 词法解析

  • 判断句子是否为空
CHECK(!statement_.empty()) << "The input statement is empty!";
  • 移除句子中的空格
statement_.erase(std::remove_if(statement_.begin(), statement_.end(),[](char c) { return std::isspace(c); }),statement_.end());
CHECK(!statement_.empty()) << "The input statement is empty!";

如果表达式层中有表达式为add(@0, @1),删除其中的空格后就会得到新的表达式add(@0,@1)

  • 逐个解析句子的字符
for (int32_t i = 0; i < statement_.size();) {char c = statement_.at(i);if (c == 'a') {  // 判断是否为addCHECK(i + 1 < statement_.size() && statement_.at(i + 1) == 'd')<< "Parse add token failed, illegal character: "<< statement_.at(i + 1);CHECK(i + 2 < statement_.size() && statement_.at(i + 2) == 'd')<< "Parse add token failed, illegal character: "<< statement_.at(i + 2);Token token(TokenType::TokenAdd, i, i + 3);tokens_.push_back(token);std::string token_operation =std::string(statement_.begin() + i, statement_.begin() + i + 3);token_strs_.push_back(token_operation);i = i + 3;} 

假设字符 c 表示当前的字符。如果 c 等于字符 ‘a’,根据词法规定,Token 中以 ‘a’ 开头的情况只有 add。因此,需要判断接下来的两个字符是否分别是 ‘d’ 和 ‘d’。如果不是,则报错。如果是的话,则初始化一个新的 Token,并保存其在表达式中的初始和结束位置。

如果某个字符 c 是 ‘m’,需要判断接下来的字符是否是 ‘u’ 和 ‘l’。如果不满足条件,则说明我们的表达式中出现了词汇表之外的单词(因为词汇表只允许以 ‘m’ 开头的单词是 “mul”)。如果满足条件,同样会初始化一个 Token 实例,并保存该单词的起始和结束位置,以及 Token 的类型。

else if (c == '@') {CHECK(i + 1 < statement_.size() && std::isdigit(statement_.at(i + 1)))<< "Parse number token failed, illegal character: " << c;int32_t j = i + 1;for (; j < statement_.size(); ++j) {if (!std::isdigit(statement_.at(j))) {break;}}Token token(TokenType::TokenInputNumber, i, j);CHECK(token.start_pos < token.end_pos);tokens_.push_back(token);std::string token_input_number = std::string(statement_.begin() + i, statement_.begin() + j);token_strs_.push_back(token_input_number);i = j;

如果第一个字符是 ‘@’,需要读取 ‘@’ 后面的所有数字,例如对于@31231,需要读取@符号之后的所有数字。如果紧跟在 ‘@’ 后面的字符不是数字,则报错。如果是数字,则将这些数字全部读取并组成一个单词(Token)。

else if (c == ',') {Token token(TokenType::TokenComma, i, i + 1);tokens_.push_back(token);std::string token_comma =std::string(statement_.begin() + i, statement_.begin() + i + 1);token_strs_.push_back(token_comma);i += 1;
}

如果第一个字符是’,'逗号,直接读取这个字符作为一个新的Token。

最后,在正确解析和创建这些 Token 后,将它们放入名为 tokens 的数组中,以便进行后续处理。

tokens_.push_back(token);

三、语法解析

3.1 语法树的定义

struct TokenNode {int32_t num_index = -1;std::shared_ptr<TokenNode> left = nullptr;std::shared_ptr<TokenNode> right = nullptr;TokenNode(int32_t num_index, std::shared_ptr<TokenNode> left, std::shared_ptr<TokenNode> right);TokenNode() = default;
};

在进行语法分析时,可以根据词法分析得到的 token 数组构建抽象语法树。抽象语法树是一个由二叉树组成的结构,每个节点都存储了操作符号或值,并通过左子节点和右子节点与其他节点连接

对于表达式 “add (@0, @1)”,当 num_index 等于 1 时,表示计算数为 @0;当 num_index 等于 2 时,表示计算数为 @1。若 num_index 为负数,则说明当前节点是一个计算节点,如 “mul” 或 “add” 等。

以下是一个简单的示例:

     add/   \@0     @1

在这个示例中,根节点是 “add”,左子节点是 “@0”,右子节点是 “@1”。这个抽象语法树表示了一个将 “@0” 和 “@1” 进行相加的表达式。通过将词法分析得到的 token 数组解析并构建抽象语法树,可以进一步对表达式进行语义分析和求值等操作。

3.2 语法树构建

语法解析即解析Token数组,构建抽象语法树,其过程是递归向下的,定义在Generate_函数中。

std::shared_ptr<TokenNode> ExpressionParser::Generate_(int32_t &index) {CHECK(index < this->tokens_.size());const auto current_token = this->tokens_.at(index);CHECK(current_token.token_type == TokenType::TokenInputNumber|| current_token.token_type == TokenType::TokenAdd || current_token.token_type == TokenType::TokenMul);

这个函数处理的对象是词法解析的Token(单词)数组,因为Generate_是一个递归函数,所以index参数指向Token数组中的当前处理位置.

current_token表示当前被处理的Token,它作为当前递归层的第一个Token,必须是以下类型之一。

TokenInputNumber = 0,
TokenAdd = 2,
TokenMul = 3,
  • **如果当前Token的类型是输入数字类型,那么会直接返回一个操作数Token作为叶子节点,不再进行下一层递归(如下)。**例如,在表达式add(@0, @1)中的@0和@1被归类为输入数字类型的Token,在解析到这两个Token时会直接创建并返回语法树节点TokenNode
if (current_token.token_type == TokenType::TokenInputNumber) {uint32_t start_pos = current_token.start_pos + 1;uint32_t end_pos = current_token.end_pos;CHECK(end_pos > start_pos);CHECK(end_pos <= this->statement_.length());const std::string &str_number =std::string(this->statement_.begin() + start_pos, this->statement_.begin() + end_pos);return std::make_shared<TokenNode>(std::stoi(str_number), nullptr, nullptr);}
  • 如果当前Token的类型是mul或者add,需要进行下一层递归来构建对应的左子节点和右子节点。

例如,在处理add(@1,@2)时,遇到add token之后,如下的第一行代码,需要做以下的两步:

  1. 首先判断是否存在左括号(left bracket)
  2. 然后继续向下递归以获取@1,如下的第14行到17行代码,但由于@1代表的是数字类型,递归后立即返回,如以上代码块中第一行对数字类型Token的处理。
else if (current_token.token_type == TokenType::TokenMul || current_token.token_type == TokenType::TokenAdd) {std::shared_ptr<TokenNode> current_node = std::make_shared<TokenNode>();current_node->num_index = -int(current_token.token_type);index += 1;CHECK(index < this->tokens_.size());// 判断add之后是否有( left bracketCHECK(this->tokens_.at(index).token_type == TokenType::TokenLeftBracket);index += 1;CHECK(index < this->tokens_.size());const auto left_token = this->tokens_.at(index);// 判断当前需要处理的left token是不是合法类型if (left_token.token_type == TokenType::TokenInputNumber|| left_token.token_type == TokenType::TokenAdd || left_token.token_type == TokenType::TokenMul) {// (之后进行向下递归得到@0current_node->left = Generate_(index);} else {LOG(FATAL) << "Unknown token type: " << int(left_token.token_type);}
}

在第17行当左子树递归构建完毕后,将它赋值到add节点的左子树上。对于表达式add(@0, @1),我们将左子树连接到current_nodeleft指针中,随后我们开始构建右子树。

add
ant 0
	index += 1; // 当前的index指向add(@1,@2)中的逗号CHECK(index < this->tokens_.size());// 判断是否是逗号CHECK(this->tokens_.at(index).token_type == TokenType::TokenComma);index += 1;CHECK(index < this->tokens_.size());// current_node->right = Generate_(index);构建右子树const auto right_token = this->tokens_.at(index);if (right_token.token_type == TokenType::TokenInputNumber|| right_token.token_type == TokenType::TokenAdd || right_token.token_type == TokenType::TokenMul) {current_node->right = Generate_(index);} else {LOG(FATAL) << "Unknown token type: " << int(left_token.token_type);}index += 1;CHECK(index < this->tokens_.size());CHECK(this->tokens_.at(index).token_type == TokenType::TokenRightBracket);return current_node;

随后需要判断@0之后是否存在comma token,如上代码中的第五行。在构建右子树的过程中,对于表达式add(@1,@2),当index指向逗号的位置时,首先需要判断是否存在逗号。接下来,我们开始构建右子树,在右子树的向下递归分析中,会得到@2作为一个叶子节点。

当右子树构建完成后,将该节点(即Generate_返回的TokenNode,此处为一个叶子节点,其数据为@1)放置于current_noderight指针中。

add
ant 0
ant 1

假设现在表达式层中的表达式是:add(@0,@1)。在词法解析模块中,这个表达式将被构建成一个单词(Token)数组,如以下:

  • add
  • (
  • @0
  • ,
  • @1
  • )

词法解析结束之后,这个表达式将被传递到语法解析模块中,用于构建抽象语法树。Generate_函数首先检查Token数组中的当前单词(Token)是否是以下类型的一种:

CHECK(index < this->tokens_.size());
const auto current_token = this->tokens_.at(index);
CHECK(current_token.token_type == TokenType::TokenInputNumber ||current_token.token_type == TokenType::TokenAdd ||current_token.token_type == TokenType::TokenMul);

当前的索引为0,表示正在处理Token数组中的"add"单词。针对这个输入,需要判断其后是否是"左括号"来确定其合法性。如果是合法的(add单词之后总存在括号),我们将构建一个左子树。因为对于一个add调用,它的后面总是跟着一个左括号"(",如下方代码的第8行。

else if (current_token.token_type == TokenType::TokenMul ||current_token.token_type == TokenType::TokenAdd) {std::shared_ptr<TokenNode> current_node = std::make_shared<TokenNode>();current_node->num_index = int(current_token.token_type);index += 1;CHECK(index < this->tokens_.size()) << "Missing left bracket!";CHECK(this->tokens_.at(index).token_type == TokenType::TokenLeftBracket);index += 1;CHECK(index < this->tokens_.size()) << "Missing correspond left token!";const auto left_token = this->tokens_.at(index);

在以上代码的第8行中,我们对’add’之后的一个Token进行判断,如果是左括号则匹配成功,开始匹配括号内的元素。对于输入add(@0, @1),在第10行中,当对索引进行+1操作后,我们需要开始解析括号内左侧的元素left_token.

随后开始递归构建表达式的左子树:

if (left_token.token_type == TokenType::TokenInputNumber ||left_token.token_type == TokenType::TokenAdd ||left_token.token_type == TokenType::TokenMul) {current_node->left = Generate_(index);
} 

对于当前的例子,当前索引(index)指向的单词是@0。在这种情况下,由于索引指向的位置是一个输入数字@0(TokenType::TokenInputNumber)的类型,所以该节点进入递归调用后将直接返回。

根据前文给出的例子,add的左子树构建完毕后,下一步需要判断中add(@0,@1)的@0之后是否存在逗号

add
ant 0
index += 1;
CHECK(index < this->tokens_.size()) << "Missing comma!";
CHECK(this->tokens_.at(index).token_type == TokenType::TokenComma);

接下来,要为如上的二叉树构建右子树:

const auto right_token = this->tokens_.at(index);
if (right_token.token_type == TokenType::TokenInputNumber ||right_token.token_type == TokenType::TokenAdd ||right_token.token_type == TokenType::TokenMul) {current_node->right = Generate_(index);
} else {LOG(FATAL) << "Unknown token type: " << int(right_token.token_type);
}

同样,由于当前索引(index)指向的位置是@1,它是一个输入数据类型,所以该节点在进入递归调用后将直接返回,并成为add节点的右子树,如下方代码所示。

add
ant 0
ant 1
std::shared_ptr<TokenNode> ExpressionParser::Generate_(int32_t &index) {CHECK(index < this->tokens_.size());......如果是Input Number就直接返回if (current_token.token_type == TokenType::TokenInputNumber) {uint32_t start_pos = current_token.start_pos + 1;uint32_t end_pos = current_token.end_pos;CHECK(end_pos > start_pos);CHECK(end_pos <= this->statement_.length());const std::string &str_number =std::string(this->statement_.begin() + start_pos, this->statement_.begin() + end_pos);return std::make_shared<TokenNode>(std::stoi(str_number), nullptr, nullptr);}
}

如果现在有一个表达式add(mul(@0,@1),@2),在词法解析中,它将被分割成以下的数个单词组成的数组:

  1. add
  2. left bracket
  3. mul
  4. left bracket
  5. @0
  6. comma
  7. @1
  8. right bracket
  9. comma
  10. @2
  11. right bracket

当以上的数组被输入到语法解析中后,index的值等于0。随后判断index指向位置的单词类型是否符合要求。

CHECK(current_token.token_type == TokenType::TokenInputNumber ||current_token.token_type == TokenType::TokenAdd ||current_token.token_type == TokenType::TokenMul);

如果该表达式的第一个单词是"add",那么我们就像之前的例子一样,将它作为二叉树的左子树进行构建。

if (left_token.token_type == TokenType::TokenInputNumber ||left_token.token_type == TokenType::TokenAdd ||left_token.token_type == TokenType::TokenMul) {current_node->left = Generate_(index);

已知表达式为add(mul(@0,@1),@2),在处理完这个表达式的左括号之后,当前指向的标记是"mul",它不属于输入参数类型。因此,在调用Generate_函数时,我们将对"mul"子表达式进行递归分析。

对"mul"子表达式解析的方式和对add(@0,@1)解析的方式相同,"mul"子表达式的分析结果如下图所示:

mul
ant 0
ant 1

在子表达式的解析完成并返回后,我们将这颗子树插入到当前节点的左指针上**(current_node->left = Generate_(index))**

add
mul
ant 0
ant 1

随后我们开始解析add(mul(@0,@1),@2)表达式中@2以及其之后的部分作为add的右子树。

if (right_token.token_type == TokenType::TokenInputNumber ||right_token.token_type == TokenType::TokenAdd ||right_token.token_type == TokenType::TokenMul) {current_node->right = Generate_(index);
} else {LOG(FATAL) << "Unknown token type: " << int(right_token.token_type);
}

在第4行调用Generate_之后,由于@2是一个输入数类型,不再进行递归分析,所以它将被直接返回并赋值给current_node->right

add
mul
ant 0
ant 1
ant 2
std::shared_ptr<TokenNode> ExpressionParser::Generate_(int32_t &index) {CHECK(index < this->tokens_.size());......如果是Input Number就直接返回if (current_token.token_type == TokenType::TokenInputNumber) {uint32_t start_pos = current_token.start_pos + 1;uint32_t end_pos = current_token.end_pos;CHECK(end_pos > start_pos);CHECK(end_pos <= this->statement_.length());const std::string &str_number =std::string(this->statement_.begin() + start_pos, this->statement_.begin() + end_pos);return std::make_shared<TokenNode>(std::stoi(str_number), nullptr, nullptr);}
}

3.3 语法树的转换(逆波兰式)

以一个简单的例子来说明,对于计算式add(@0,@1),首先遇到的节点是add,但在遇到add时缺少进行计算所需的具体数据@0@1

**因此,还需要进行逆波兰转换,将操作数放在前面,计算放在后面。**该转换的实现非常简单,只需对原有的二叉树进行后续遍历即可:

void ReversePolish(const std::shared_ptr<TokenNode> &root_node,std::vector<std::shared_ptr<TokenNode>> &reverse_polish) {if (root_node != nullptr) {ReversePolish(root_node->left, reverse_polish);ReversePolish(root_node->right, reverse_polish);reverse_polish.push_back(root_node);}
}

逆波兰式化后的表达如下:

对于 add (@0,@1),逆波兰式为:@0,@1,add

对于 add(mul(@0,@1),@2),逆波兰式为:@0,@1,mul,@2,add

通过逆波兰转换,可以将原式转换为计算式的输入数放在前面,操作符号放在后面的形式。逆波兰式的特点是消除了括号的需求,使得计算顺序更加清晰和直观。

小节:

经过这样的转换,可以确保在每次遇到计算节点时所需的操作数已经准备就绪。

  1. 首先,传入一个表达式字符串,例如add(mul(@0,@1),@2)

  2. 接下来,对add(mul(@0,@1),@2)进行词法分析,将其拆分为多个tokens,在拆分过程中需要进行词法校验。

  3. 然后,根据已知的tokens数组,通过递归向下遍历进行语法分析,从而得到相应的计算二叉树。计算二叉树的各个节点可以是add、mul或者@0、@1等。

  4. 最后,对计算二叉树进行逆波兰变换(后续遍历),得到的逆波兰式如下:@0、@1、mul、@2、add。

image-20240822205215268

四、表达式层

4.1 ExpressionLayer和ExpressionParser类

class ExpressionLayer : public NonParamLayer {public:explicit ExpressionLayer( std::string statement);InferStatus Forward(const std::vector<std::shared_ptr<Tensor<float>>>& inputs,std::vector<std::shared_ptr<Tensor<float>>>& outputs) override;static ParseParameterAttrStatus GetInstance(const std::shared_ptr<RuntimeOperator>& op,std::shared_ptr<Layer>& expression_layer);private:std::string statement_;std::unique_ptr<ExpressionParser> parser_;
};

在表达式层面有两个类内变量:

  • statement_:需要解析的表达式
  • parser_: 解析器,用于词法分析和语法分析的 ExpressionParser,解析表达式statement_

ExpressionParser类定义在include/parser/parse_expression.hpp

class ExpressionParser {public:explicit ExpressionParser(std::string statement): statement_(std::move(statement)) {}/*** 词法分析* @param retokenize 是否需要重新进行语法分析*/void Tokenizer(bool retokenize = false);/*** 语法分析* @return 生成的语法树*/std::vector<std::shared_ptr<TokenNode>> Generate();......private:std::shared_ptr<TokenNode> Generate_(int32_t& index);// 被分割的词语数组std::vector<Token> tokens_;// 被分割的字符串数组std::vector<std::string> token_strs_;// 待分割的表达式std::string statement_;
};

主要负责将输入的表达式字符串转换成一种中间表示形式,通常是逆波兰表示法或抽象语法树(AST)。这个类通常包括:

  • Tokenizer: 将表达式字符串拆分成 token。

  • Generate: 解析这些 tokens,构建树结构或其他中间表示形式。

  • 生成逆波兰表示法或抽象语法树: 供其他部分使用。

ExpressionParser中,,Tokenizer用于将表达式分割为多个单词,例如将add(@0,@1)分割为以下的几个单词:

  1. add
  2. left bracket
  3. input number(@0)
  4. comma

ExpressionParser生成语法树Generate函数:

std::vector<std::shared_ptr<TokenNode>> ExpressionParser::Generate() {if (this->tokens_.empty()) {this->Tokenizer(true);}int index = 0;// 构建语法树std::shared_ptr<TokenNode> root = Generate_(index);CHECK(root != nullptr);CHECK(index == tokens_.size() - 1);// 转逆波兰式,之后转移到expression中std::vector<std::shared_ptr<TokenNode>> reverse_polish;ReversePolish(root, reverse_polish);return reverse_polish;
}

在以上的代码中,首先调用 Generate_ 方法。该方法通过对 tokens 数组进行语法分析,生成一棵抽象语法树。然后对这棵抽象语法树进行逆波兰排序,得到最终的执行序列。

4.2 表达式层的注册

LayerRegistererWrapper用于将 GetInstance方法与 “pnnx.Expression” 字符串进行绑定,使得运行时系统能够通过 “pnnx.Expression” 创建 ExpressionLayer实例,将其注册到推理框架中,推理框架才能够使用自定义的表达式层进行计算

LayerRegistererWrapper kExpressionGetInstance("pnnx.Expression",ExpressionLayer::GetInstance);

初始化过程如下:

ParseParameterAttrStatus ExpressionLayer::GetInstance(const std::shared_ptr<RuntimeOperator>& op,std::shared_ptr<Layer>& expression_layer) {// 检查传入的 RuntimeOperator 对象指针是否为空CHECK(op != nullptr) << "Expression operator is nullptr";// 获取 RuntimeOperator 的参数集合const auto& params = op->params;// 确认参数中包含 "expr" 键if (params.find("expr") == params.end()) {return ParseParameterAttrStatus::kParameterMissingExpr;}// 将 "expr" 参数转换为 RuntimeParameterString 类型auto statement_param = std::dynamic_pointer_cast<RuntimeParameterString>(params.at("expr"));// 如果转换失败,则记录错误并返回参数缺失状态if (statement_param == nullptr) {LOG(ERROR) << "Can not find the expression parameter";return ParseParameterAttrStatus::kParameterMissingExpr;}// 检查参数的类型是否为字符串if (statement_param->type != RuntimeParameterType::kParameterString) {LOG(ERROR) << "The expression parameter should be of type string";return ParseParameterAttrStatus::kParameterMissingExpr;}// 使用提取到的表达式字符串创建 ExpressionLayer 实例expression_layer = std::make_shared<ExpressionLayer>(statement_param->value);// 返回成功状态return ParseParameterAttrStatus::kParameterAttrParseSuccess;
}

在以上的代码中 statement_param是从PNNX中提取表达式字符串expr,然后使用该字符串来实例化算子。

4.3 表达式层的输入处理

在Expression Layer的Forward函数输入中,也就是在下面这个数组中,多个输入依次排布:

const std::vector<std::shared_ptr<Tensor<float>>>& inputs

如果batch_size的大小为4,那么input1中的元素数量为4,input2的元素数量也为4。input1中的数据都来源于同一批次的操作数1(operand 1),input2中的数据都来源于同一批次的操作数2(operand 2)。其中,input1中的4(batch size = 4)个元素都是来自于操作数1,而input2中的4(batch size = 4)个元素都是来自于操作数2,它们在inputs数组参数中依次排列,如下图所示:

image-20240830194408863
  • 对于两个输入操作数

已知有如上的数据存储排布, 下面将讨论如何根据现有的数据完成add(@0,@1)计算. 可以看到每一次计算的时候, 都以此从input1input2中取得一个数据进行加法操作, 并存放在对应的输出位置.

image-20240830194432203
  • 对于三个输入操作数

下图的例子展示了对于三个输入,mul(add(@0,@1),@2)的情况:

image-20240830194454919

每次计算的时候依次从input1, input2input3中取出数据, 并作出相应的运算, 并将结果数据存放于对应的output中。我们简单说明一下:

  1. o u t p u t 1 = ( i n p u t 1 + i n p u t 5 ) × i n p u t 9 output_1=(input_1+input_5)\times input_9 output1=(input1+input5)×input9,对于第一个输出数据,我们先从取出第一组输入(@0)中第一个输入数据 i n p u t 1 input_1 input1,再从第二组输入(@1)中取得第一个输入数据 i n p u t 5 input_5 input5,最后再从第三组输入(@2)中取得第一个输入数据 i n p u t 9 input_9 input9.
  2. o u t p u t 2 = ( i n p u t 2 + i n p u t 6 ) × i n p u t 10 output_2=(input_2+input_6)\times input_{10} output2=(input2+input6)×input10,对于第一个输出数据,我们先从取出第一组输入(@0)中第一个输入数据 i n p u t 2 input_2 input2,再从第二组输入(@1)中取得第一个输入数据 i n p u t 6 input_6 input6,最后再从第三组输入(@2)中取得第一个输入数据 i n p u t 10 input_{10} input10.
  3. o u t p u t 3 = ( i n p u t 3 + i n p u t 7 ) × i n p u t 11 output_3=(input_3+input_7)\times input_{11} output3=(input3+input7)×input11,对于第一个输出数据,我们先从取出第一组输入(@0)中第一个输入数据 i n p u t 3 input_3 input3,再从第二组输入(@1)中取得第一个输入数据 i n p u t 7 input_7 input7,最后再从第三组输入(@2)中取得第一个输入数据 i n p u t 11 input_{11} input11.
  4. o u t p u t 4 output_4 output4同理。

4.4 表达式层的计算过程

表达式层同样继承于算子的父类Layer,并重写其中的Forward方法。在Forward方法中,定义了表达式层的计算逻辑,即数据结构中逆波兰式的计算。

假设现在有一个计算式为add(mul(@0,@1),@2),通过抽象语法树构建和逆波兰转换,得到了以下序列:@0、@1、mul、@2、add。

image-20240822183251160

在Forward函数中,定义了一个栈式计算模块,并维护了一个输入数栈。输入数栈是一个先进后出的数据结构,用于存放表达式中的输入数。

对于给定的表达式,例如add(mul(@0,@1),@2) --> (@0 @1 mul @2 add),将前两个输入数依次压入输入数栈中。在序列中的下一个 节点是mul,它的作用是将两个输入数相乘,因此,需要从输入数栈中顺序地弹出两个输入数进行相乘操作。通过该操作,得到一个中间结果@tmp1,需要将这个中间结果存放到输入数栈中,以便供后续步骤处理。

在序列中的下一个节点是@2,是一个输入操作数,将它存放到输入数栈中。在序列中的最后一个节点是add,它是一个加法节点,需要两个输入数据。因此,它会将栈中的@2@tmp1全部弹出,进行加法操作,得到整个计算序列最后的结果。

代码实现如下:

if (token_node->num_index >= 0) {// process operatoruint32_t start_pos = token_node->num_index * batch_size;std::vector<std::shared_ptr<Tensor<float>>> input_token_nodes;for (uint32_t i = 0; i < batch_size; ++i) {CHECK(i + start_pos < inputs.size())<< "The " << i<< "th operand doesn't have appropriate number of tensors";input_token_nodes.push_back(inputs.at(i + start_pos));}op_stack.push(input_token_nodes);
} 

根据输入的逆波兰式@0,@1,add,遇到的第一个节点是操作数是@0,所以栈op_stack内的内存布局如下:

image-20240822204204039

当按顺序遇到第二个节点(op)时,也就是操作数@1时,将从inputs中读取操作数并将其存放到input_token_nodes中。然后,将input_token_nodes这一批次的数据放入栈中。

运算符的代码处理

const int32_t op = token_node->num_index;
...
std::vector<std::shared_ptr<Tensor<float>>> input_node1 = op_stack.top();
...
...
op_stack.pop();std::vector<std::shared_ptr<Tensor<float>>> input_node2 = op_stack.top();
CHECK(input_node2.size() == batch_size)<< "The second operand doesn't have appropriate number of tensors, ""which need "<< batch_size;
op_stack.pop();

当节点(op)类型为操作符号时(也就是num_index小于0的时候),首先从栈(op_stack)中弹出两个批次的操作数。对于给定情况,input_node1存放的是input1input4,而input_node2存放的是input5input8

std::vector<std::shared_ptr<Tensor<float>>> output_token_nodes(batch_size);for (uint32_t i = 0; i < batch_size; ++i) {// do executionif (op == int(TokenType::TokenAdd)) {output_token_nodes.at(i) =TensorElementAdd(input_node1.at(i), input_node2.at(i));} else if (op == int(TokenType::TokenMul)) {output_token_nodes.at(i) =TensorElementMultiply(input_node1.at(i), input_node2.at(i));} else {LOG(FATAL) << "Unknown operator type: " << op;}
}
op_stack.push(output_token_nodes);

在获取大小为batch_sizeinput_node1input_node2后,流程将在for(int i = 0...batch_size)循环中对这两个输入进行两两操作(input1+input5, input2+input6),具体的操作类型定义于当前的op中。最后,我们将计算得到的结果放入输入数栈op_stack中。

五、计算图执行函数

在完成计算图所有算子的注册后,还需要对计算图中的算子按照拓扑排序依次计算得到推理结果。这里主要涉及RuntimeGraph类中的ProbeNextLayer()函数和Forward()函数。

5.1 RuntimeGraph::ProbeNextLayer()

ProbeNextLayer函数实现了一个关键的图结构操作:确保在执行前向计算时,当前层的输出数据能够正确地传递到依赖该输出的下一层节点

void RuntimeGraph::ProbeNextLayer(const std::shared_ptr<RuntimeOperator> &current_op,  // 表示当前的算子// 当前算子的输出,需要把当前算子的输出赋值到它后继节点的输入中const std::vector<std::shared_ptr<Tensor<float>>> &layer_output_datas) {// 当前节点的后继节点集合const auto &next_ops = current_op->output_operators;// 对所有后继节点进行遍历for (const auto &[_, next_rt_operator] : next_ops) {// 得到后继节点的输入操作数集合const auto &next_input_operands = next_rt_operator->input_operands;// 检查当前操作符的输出是否作为后继节点的输入if (next_input_operands.find(current_op->name) != next_input_operands.end()) {/*** next_input_operands:* {*    输入1 -- current_op.name: current_op对应的输出空间*    输入2 -- other_op.name: other_op对应的输出空间* }*/// 得到后继节点的关于 current_op 输出的输入张量空间 next_input_datasstd::vector<std::shared_ptr<ftensor>> &next_input_datas =next_input_operands.at(current_op->name)->datas;// 确保后继节点的输入空间大小与当前操作符的输出空间大小一致CHECK(next_input_datas.size() == layer_output_datas.size());// 将当前操作符的输出数据赋值到后继节点的输入数据空间中for (int i = 0; i < next_input_datas.size(); ++i) {next_input_datas.at(i) = layer_output_datas.at(i);}}}
}

ProbeNextLayer函数中,对当前节点的所有后继节点进行依次的遍历,并将当前节点的输出赋值给后继节点的输入。该过程的三个主要步骤:

  1. 遍历后继节点

    • 函数首先获取当前节点 (current_op) 的所有后继节点 (next_ops)。
    • 然后,遍历每一个后继节点 (next_rt_operator)。
  2. 识别输入输出关系

    • 对于每一个后继节点,函数会识别它的输入张量空间 (next_input_operands)。
    • 具体来说,它会寻找与当前节点 (current_op) 对应的输入张量空间,因为每个后继节点可能从多个前驱节点获取输入
  3. 数据赋值

    • 一旦找到对应的输入张量空间 (next_input_datas),函数会将当前节点 (current_op) 的输出数据 (layer_output_datas) 赋值给后继节点的相应输入张量。

5.2 RuntimeGraph::Forward()

在前面已经详细讲解了模型中所有算子的执行顺序排序,排序的顺序也称为拓扑序。通过对模型中的所有算子进行拓扑排序,可以得到一个算子序列 topo_operators_。因此在执行函数 Forward 时,只需按照顺序依次执行算子序列 (topo_operators) 中每个算子的 Forward 方法即可,因为每个算子(Layer)的具体计算逻辑都实现在它重载的 Forward 方法中。

std::vector<std::shared_ptr<Tensor<float>>> RuntimeGraph::Forward(const std::vector<std::shared_ptr<Tensor<float>>>& inputs, bool debug) {......for (const auto& op : topo_operators_) { op->has_forward = false; } // 重置每个操作符的前向计算标记// 遍历按拓扑排序的操作符队列,逐一执行前向计算for (const auto& current_op : topo_operators_) {// 如果当前操作符是输入节点(pnnx.Input)if (current_op->type == "pnnx.Input") {current_op->has_forward = true;ProbeNextLayer(current_op, inputs);  // 将输入数据传递到下一层}else if (current_op->type == "pnnx.Output") { // 如果当前操作符是输出节点(pnnx.Output)current_op->has_forward = true;// 确保输出节点只有一个输入操作数CHECK(current_op->input_operands_seq.size() == 1);// 将输入操作数设置为输出操作数current_op->output_operands = current_op->input_operands_seq.front();}else {  // 如果当前操作符是普通的计算层// 执行当前操作符的前向计算InferStatus status = current_op->layer->Forward();// 检查前向计算是否成功CHECK(status == InferStatus::kInferSuccess) << current_op->layer->layer_name()<< " layer forward failed, error code: " << int(status);current_op->has_forward = true;// 将计算结果传递到下一层ProbeNextLayer(current_op, current_op->output_operands->datas);}}.......// 获取最终的输出操作数,并返回其数据if (operators_maps_.find(output_name_) != operators_maps_.end()) {const auto& output_op = operators_maps_.at(output_name_);CHECK(output_op->output_operands != nullptr)<< "Output from " << output_op->name << " is empty";const auto& output_operand = output_op->output_operands;return output_operand->datas;} else {LOG(FATAL) << "Can not find the output operator " << output_name_;return std::vector<std::shared_ptr<Tensor<float>>>{};}
}

RuntimeGraph::Forward函数的主要任务是根据拓扑排序顺序依次执行图中的所有操作符的前向计算,最终生成并返回计算结果。

在算子执行的阶段,依次取出topo_operators数组中的所有算子,并调用它们的Forward方法。在遍历执行的循环内,可以将算子分为三类:输入类型输出类型普通算子。普通算子包括卷积、池化、线性激活等算子。

  • pnnx.Input输入类型算子

对于执行输入类型算子,由于它是输入节点,无需调用 Forward 方法进行算子的计算,只需要将输入节点的输入拷贝到下一级节点即可完成操作。

if (current_op->type == "pnnx.Input"){current_op->has_forward = true;ProbeNextLayer(current_op, inputs);
}
  • pnnx.Output输出类型算子

通过设置 current_op->has_forward = true,标记这个输出操作符已经完成前向计算。从 current_op 的 input_operands_seq 中获取输入操作数,并将其设置为 current_op->output_operands。这里假设输出操作符只有一个输入。

else if (current_op->type == "pnnx.Output") {current_op->has_forward = true;CHECK(current_op->input_operands_seq.size() == 1);current_op->output_operands = current_op->input_operands_seq.front();
}
  • 普通类型算子
else {InferStatus status = current_op->layer->Forward();CHECK(status == InferStatus::kInferSuccess)<< current_op->layer->layer_name()<< " layer forward failed, error code: " << int(status);current_op->has_forward = true;ProbeNextLayer(current_op, current_op->output_operands->datas);
}// Forward方法完成了计算,并且把输出存放在current_op的输出张量空间中
// ProbeNextLayer要把current_op的输出张量传递到下一级、后继节点的输入空间当中。

执行普通类型的算子和执行输入类型的算子流程大致相同。但是在执行普通类型的算子时,先调用当前算子的Forward重载函数,Forward函数中会根据各类算子自定义的计算逻辑对输入张量完成计算,例如卷积算子就对输入张量计算卷积,池化算子在Forward函数就对输入张量计算池化。在当前算子执行完毕后,通过ProbeNextLayer函数将当前节点的输出赋值给后继节点的输入张量,具体流程同上所述。

在所有算子执行完毕后,需要在所有算子的合集中查询与output_name名称对应的算子(output_op),然后获取该算子对应的输出张量,从执行图中提取最终的输出数据,即整个模型的推理输出output_operand。

// 在build阶段已经指定了output_name_,找到output_name_对应的算子
if (operators_maps_.find(output_name_) != operators_maps_.end()) {const auto& output_op = operators_maps_.at(output_name_);CHECK(output_op->output_operands != nullptr)<< "Output from" << output_op->name << " is empty";const auto& output_operand = output_op->output_operands;// 返回输出算子对应的output张量空间return output_operand->datas;
}

首先检查 operators_maps_中是否存在名为 output_name_ 的节点(输出节点)。如果输出操作数有效,代码从输出操作数中提取出数据 (output_operand->datas) 并返回。这个数据就是整个计算图最终的输出结果。最后output_operand会再作为Graph.Forward函数调用输出(模型的预测输出)的返回

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/412580.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

插入排序的动画展示与实现

排序学习思路&#xff1a;先实现单趟逻辑&#xff0c;在实现整体逻辑&#xff1b;先解决普遍情况&#xff0c;再解决特殊情况。 什么是插入排序 回忆下自己玩扑克牌的时候是怎么把手上的牌理顺的吧&#xff01;其实那就是插入排序&#xff0c;从左边往右边&#xff0c;把一张张…

Profinet 从站转 EtherNet/IP 从站网关

产品用途 本产品是 PN(Profinet) 和 EtherNet/IP 网关&#xff0c;使用数据映射方式工作。 本产品在 PN 侧作为 PN IO 从站&#xff0c;接 PN 主站设备&#xff0c;比如西门子 PLC 等&#xff1b;在EtherNet/IP 侧做为 EtherNet/IP 从站&#xff0c;接 EtherNet…

C++:继承用法详解~

在学完C的类和对象&#xff0c;并掌握了类的核心语法与基本用法之后&#xff1b;我们就得去学习一下继承的语法&#xff0c;与继承的用法。简单概括一下&#xff0c;继承是C中一种代码复用的手段&#xff0c;它允许我们&#xff0c;对已有的类&#xff0c;增添新的成员函数或变…

opencv实战项目十六:kmeans图像颜色聚类:

文章目录 前言K-means介绍效果 前言 在数字化时代&#xff0c;图像处理技术已成为计算机视觉领域的重要组成部分。其中&#xff0c;图像颜色聚类作为一项关键技术在众多应用场景中发挥着重要作用&#xff0c;如图像分割、物体识别、色彩调整等。K-means算法作为一种经典的聚类…

【云游戏】点量云流赋能大型游戏新体验

点量小刘发现近期国产化大型3A游戏《黑神话&#xff1a;悟空》的发售&#xff0c;可谓是赢得了一波好评。从场景内容来说深厚的文化底蕴支撑和高质量精美的特效及画面制作令人眼前一亮&#xff0c;作为备受瞩目的一款游戏&#xff0c;从技术层面来说&#xff0c;该游戏也离不开…

【DSP+FPGA】基于DSP+FPGA XC7K325T与TMS320C6678的通用信号处理平台

DSP FPGA 协同处理架构板载 1 个TMS320C6678 多核DSP处理节点板载 1 片 XC7K325T FPGA处理节点板载 1 个FMC 接口板载4路SFP光纤接口FPGA 与 DSP 之间采用高速Rapid IO互联 基于FPGA与DSP协同处理架构的通用高性能实时信号处理平台&#xff0c;该平台采用1片TI的KeyStone系列多…

制造企业如何启用BI工具,并构建自助式BI业务模式?

在制造业的数字化转型浪潮中&#xff0c;商业智能BI工具正逐渐成为推动企业增长的“加速引擎”。随着数据量的爆炸性增长&#xff0c;如何高效地分析和利用数据&#xff0c;已成为制造业提升竞争力的关键。本文将基于BI工具在制造业中的优势&#xff0c;深入探讨一种创新的BI分…

[Meachines] [Insane] Bankrobber XSS-MDOG+SQLI+XSRF+Local-RCE+Bankv2转账模拟应用缓冲区溢出

信息收集 IP AddressOpening Ports10.10.10.154TCP:80&#xff0c;443&#xff0c;445&#xff0c;3306 $ nmap -p- 10.10.10.154 --min-rate 1000 -sC -sV -Pn PORT STATE SERVICE VERSION 80/tcp open http …

jenkins安装k8s插件发布服务

1、安装k8s插件 登录 Jenkins&#xff0c;系统管理→ 插件管理 → 搜索 kubernetes&#xff0c;选择第二个 Kubernetes&#xff0c;点击 安装&#xff0c;安装完成后重启 Jenkins 。 2、对接k8s集群、申请k8s凭据 因为 Jenkins 服务器在 kubernetes 集群之外&#xff0c;所以…

JVM垃圾回收算法:标记-清除算法 、复制算法、 标记-整理算法、 分代收集算法

文章目录 引言I 标记回收算法(Mark-Sweep)算法不足II 复制算法(Copying)III 标记整理算法(Mark-Compact)IV 分代收集(以上三种算法的集合体)内存划分新生代算法:Minor GC老年代算法V 查看JVM堆分配引言 垃圾回收(Garbage Collection,GC) Java支持内存动态分配、…

机器学习/数据分析案例---糖尿病预测

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 前言 这是一篇数据分析/机器学习很好的入门案例&#xff0c;对糖尿病的影响进行预测和分析通过随机森林预测&#xff0c;平均准确率和召回率都不错不足&#x…

Pytorch实现多层LSTM模型,并增加emdedding、Dropout、权重共享等优化

简述 本文是 Pytorch封装简单RNN模型&#xff0c;进行中文训练及文本预测 一文的延申&#xff0c;主要做以下改动&#xff1a; 1.将nn.RNN替换为nn.LSTM&#xff0c;并设置多层LSTM&#xff1a; 既然使用pytorch了&#xff0c;自然不需要手动实现多层&#xff0c;注意nn.RNN…

Threejs之OrbitControls轨道控制器

本文目录 前言一、Orbitcontrols&#xff08;轨道控制器&#xff09;1.1 基础使用1.2 代码演示 二、效果展示 前言 Orbitcontrols&#xff08;轨道控制器&#xff09;可以使得相机围绕目标进行轨道运动。 一、Orbitcontrols&#xff08;轨道控制器&#xff09; 1.1 基础使用 C…

【Python 千题 —— 基础篇】身份证隐藏的信息

Python 千题持续更新中 …… 脑图地址 👉:⭐https://twilight-fanyi.gitee.io/mind-map/Python千题.html⭐ 题目描述 题目描述 在一个用户信息管理系统中,你需要处理和验证用户提供的身份证号。编写一个程序来从用户信息字符串中提取和验证身份证号,并提供相应的处理方式…

图论----最小生成树讲解与相关题解

目前已更新系列 当前--图论----最小生成树讲解与相关题解 滑动窗口系列算法总结与题解一 算法系列----并查集总结于相关题解 图论---dfs系列 差分与前缀和总结与对应题解&#xff08;之前笔试真的很爱考&#xff09; 数论---质数判断、质因子分解、质数筛&#xff08;埃氏…

在 Cilium CNI 集群上运行 vCluster 虚拟集群

上周在 KubeCon China 2024 大会上&#xff0c;我和社区伙伴们作为志愿者在 Cilium 项目展台与用户交流。有位用户询问 Cilium 是否能与 vCluster 集成&#xff0c;当时未能给出明确答复&#xff0c;特地回来后进行了测试。 答案是&#xff1a;在最新的 vCluster v0.20 中容器…

【Python篇】Python 类和对象:详细讲解(上篇)

文章目录 Python 类和对象&#xff1a;详细讲解1. 什么是类&#xff08;Class&#xff09;类的定义 2. 什么是对象&#xff08;Object&#xff09;创建对象 3. 属性和方法属性&#xff08;Attributes&#xff09;方法&#xff08;Methods&#xff09;在类中定义属性和方法使用对…

重生奇迹MU 小清新职业智弓MM

游戏中有一种令人迷醉的职业——智弓MM&#xff0c;她们以高超的射箭技能闻名于世。本文将为您介绍这个悠闲的小清新职业&#xff0c;在游戏中的特点以及如何成为一名出色的智弓MM。跟随我们一起探索这个奇妙而神秘的职业吧&#xff01; 悠闲的游戏节奏是游戏的初衷之一&#…

52 mysql 启动过程中常见的相关报错信息

前言 我们这里主要是看一下 service mysql start, service mysql stop 的过程中的一些常见的错误问题 这些 也是之前经常碰到, 但是 每次都是 去搜索, 尝试 1, 2, 3, 4 去解决问题 但是 从来未曾思考过 这个问题到底是 怎么造成的 The server quit without updating PID fil…

【设计模式】创建型模式——抽象工厂模式

抽象工厂模式 1. 模式定义2. 模式结构3. 实现3.1 实现抽象产品接口3.2 定义具体产品3.3 定义抽象工厂接口3.4 定义具体工厂3.5 客户端代码 4. 模式分析4.1 抽象工厂模式退化为工厂方法模式4.2 工厂方法模式退化为简单工厂模式 5. 模式特点5.1 优点5.2 缺点 6. 适用场景6.1 需要…