Skip to content

Commit

Permalink
Merge pull request #17949 from paldepind/rust-async-blocks
Browse files Browse the repository at this point in the history
Rust: Handle async blocks in CFG and SSA
  • Loading branch information
hvitved authored Nov 13, 2024
2 parents 67684d1 + 274d942 commit 2bb5603
Show file tree
Hide file tree
Showing 18 changed files with 825 additions and 534 deletions.
2 changes: 1 addition & 1 deletion rust/ql/lib/codeql/rust/controlflow/BasicBlocks.qll
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ final class BasicBlock = BasicBlockImpl;
* without branches or joins.
*/
private class BasicBlockImpl extends TBasicBlockStart {
/** Gets the scope of this basic block. */
/** Gets the CFG scope of this basic block. */
CfgScope getScope() { result = this.getAPredecessor().getScope() }

/** Gets an immediate successor of this basic block, if any. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,12 @@ query predicate nonPostOrderExpr(Expr e, string cls) {
*/
query predicate scopeNoFirst(CfgScope scope) {
Consistency::scopeNoFirst(scope) and
not scope = any(Function f | not exists(f.getBody())) and
not scope = any(ClosureExpr c | not exists(c.getBody()))
not scope =
[
any(AstNode f | not f.(Function).hasBody()),
any(ClosureExpr c | not c.hasBody()),
any(AsyncBlockExpr b | not b.hasStmtList())
]
}

/** Holds if `be` is the `else` branch of a `let` statement that results in a panic. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ private module CfgInput implements InputSig<Location> {
class CfgScope = Scope::CfgScope;

CfgScope getCfgScope(AstNode n) {
result = n.getEnclosingCallable() and
result = n.getEnclosingCfgScope() and
Stages::CfgStage::ref()
}

Expand All @@ -44,12 +44,10 @@ private module CfgInput implements InputSig<Location> {
predicate successorTypeIsCondition(SuccessorType t) { t instanceof Cfg::BooleanSuccessor }

/** Holds if `first` is first executed when entering `scope`. */
predicate scopeFirst(CfgScope scope, AstNode first) {
first(scope.(CfgScopeTree).getFirstChildNode(), first)
}
predicate scopeFirst(CfgScope scope, AstNode first) { scope.scopeFirst(first) }

/** Holds if `scope` is exited when `last` finishes with completion `c`. */
predicate scopeLast(CfgScope scope, AstNode last, Completion c) { last(scope.getBody(), last, c) }
predicate scopeLast(CfgScope scope, AstNode last, Completion c) { scope.scopeLast(last, c) }
}

private module CfgSplittingInput implements SplittingInputSig<Location, CfgInput> {
Expand All @@ -71,14 +69,7 @@ private module CfgImpl =

import CfgImpl

class CfgScopeTree extends StandardTree, Scope::CfgScope {
override predicate first(AstNode first) { first = this }

override predicate last(AstNode last, Completion c) {
last = this and
completionIsValidFor(c, this)
}

class CallableScopeTree extends StandardTree, PreOrderTree, PostOrderTree, Scope::CallableScope {
override predicate propagatesAbnormal(AstNode child) { none() }

override AstNode getChildNode(int i) {
Expand Down Expand Up @@ -280,13 +271,23 @@ module ExprTrees {
}
}

private AstNode getBlockChildNode(BlockExpr b, int i) {
result = b.getStmtList().getStatement(i)
or
i = b.getStmtList().getNumberOfStatements() and
result = b.getStmtList().getTailExpr()
}

class AsyncBlockExprTree extends StandardTree, PreOrderTree, PostOrderTree, AsyncBlockExpr {
override AstNode getChildNode(int i) { result = getBlockChildNode(this, i) }

override predicate propagatesAbnormal(AstNode child) { none() }
}

class BlockExprTree extends StandardPostOrderTree, BlockExpr {
override AstNode getChildNode(int i) {
result = this.getStmtList().getStatement(i)
or
i = this.getStmtList().getNumberOfStatements() and
result = this.getStmtList().getTailExpr()
}
BlockExprTree() { not this.isAsync() }

override AstNode getChildNode(int i) { result = getBlockChildNode(this, i) }

override predicate propagatesAbnormal(AstNode child) { child = this.getChildNode(_) }
}
Expand Down
35 changes: 31 additions & 4 deletions rust/ql/lib/codeql/rust/controlflow/internal/Scope.qll
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,31 @@ private import codeql.rust.elements.internal.generated.ParentChild

/**
* A control-flow graph (CFG) scope.
*
* A CFG scope is a callable with a body.
*/
class CfgScope extends Callable {
CfgScope() {
abstract private class CfgScopeImpl extends AstNode {
/** Holds if `first` is executed first when entering `scope`. */
abstract predicate scopeFirst(AstNode first);

/** Holds if `scope` is exited when `last` finishes with completion `c`. */
abstract predicate scopeLast(AstNode last, Completion c);
}

final class CfgScope = CfgScopeImpl;

final class AsyncBlockScope extends CfgScopeImpl, AsyncBlockExpr instanceof ExprTrees::AsyncBlockExprTree
{
override predicate scopeFirst(AstNode first) { first(super.getFirstChildNode(), first) }

override predicate scopeLast(AstNode last, Completion c) {
last(super.getLastChildElement(), last, c)
}
}

/**
* A CFG scope for a callable (a function or a closure) with a body.
*/
final class CallableScope extends CfgScopeImpl, Callable {
CallableScope() {
// A function without a body corresponds to a trait method signature and
// should not have a CFG scope.
this.(Function).hasBody()
Expand All @@ -23,4 +43,11 @@ class CfgScope extends Callable {
or
result = this.(ClosureExpr).getBody()
}

override predicate scopeFirst(AstNode first) {
first(this.(CallableScopeTree).getFirstChildNode(), first)
}

/** Holds if `scope` is exited when `last` finishes with completion `c`. */
override predicate scopeLast(AstNode last, Completion c) { last(this.getBody(), last, c) }
}
6 changes: 3 additions & 3 deletions rust/ql/lib/codeql/rust/dataflow/internal/DataFlowImpl.qll
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ final class NormalCall extends DataFlowCall, TNormalCall {
override CallCfgNode asCall() { result = c }

override DataFlowCallable getEnclosingCallable() {
result = TCfgScope(c.getExpr().getEnclosingCallable())
result = TCfgScope(c.getExpr().getEnclosingCfgScope())
}

override string toString() { result = c.toString() }
Expand Down Expand Up @@ -136,7 +136,7 @@ module Node {

ExprNode() { this = TExprNode(n) }

override CfgScope getCfgScope() { result = this.asExpr().getEnclosingCallable() }
override CfgScope getCfgScope() { result = this.asExpr().getEnclosingCfgScope() }

override Location getLocation() { result = n.getExpr().getLocation() }

Expand All @@ -156,7 +156,7 @@ module Node {

ParameterNode() { this = TParameterNode(parameter) }

override CfgScope getCfgScope() { result = parameter.getParam().getEnclosingCallable() }
override CfgScope getCfgScope() { result = parameter.getParam().getEnclosingCfgScope() }

override Location getLocation() { result = parameter.getLocation() }

Expand Down
19 changes: 12 additions & 7 deletions rust/ql/lib/codeql/rust/dataflow/internal/SsaImpl.qll
Original file line number Diff line number Diff line change
Expand Up @@ -156,19 +156,19 @@ private predicate variableReadActual(BasicBlock bb, int i, Variable v) {
*/
pragma[noinline]
private predicate hasCapturedWrite(Variable v, Cfg::CfgScope scope) {
any(VariableWriteAccess write | write.getVariable() = v and scope = write.getEnclosingCallable+())
any(VariableWriteAccess write | write.getVariable() = v and scope = write.getEnclosingCfgScope+())
.isCapture()
}

/**
* Holds if `v` is read inside basic block `bb` at index `i`, which is in the
* immediate outer scope of `scope`.
* immediate outer CFG scope of `scope`.
*/
pragma[noinline]
private predicate variableReadActualInOuterScope(
BasicBlock bb, int i, Variable v, Cfg::CfgScope scope
) {
variableReadActual(bb, i, v) and bb.getScope() = scope.getEnclosingCallable()
variableReadActual(bb, i, v) and bb.getScope() = scope.getEnclosingCfgScope()
}

pragma[noinline]
Expand Down Expand Up @@ -263,7 +263,7 @@ private predicate readsCapturedVariable(BasicBlock bb, Variable v) {
*/
pragma[noinline]
private predicate hasCapturedRead(Variable v, Cfg::CfgScope scope) {
any(VariableReadAccess read | read.getVariable() = v and scope = read.getEnclosingCallable+())
any(VariableReadAccess read | read.getVariable() = v and scope = read.getEnclosingCfgScope+())
.isCapture()
}

Expand All @@ -273,14 +273,18 @@ private predicate hasCapturedRead(Variable v, Cfg::CfgScope scope) {
*/
pragma[noinline]
private predicate variableWriteInOuterScope(BasicBlock bb, int i, Variable v, Cfg::CfgScope scope) {
SsaInput::variableWrite(bb, i, v, _) and scope.getEnclosingCallable() = bb.getScope()
SsaInput::variableWrite(bb, i, v, _) and scope.getEnclosingCfgScope() = bb.getScope()
}

/** Holds if evaluating `e` jumps to the evaluation of a different CFG scope. */
private predicate isControlFlowJump(Expr e) { e instanceof CallExprBase or e instanceof AwaitExpr }

/**
* Holds if the call `call` at index `i` in basic block `bb` may reach
* a callable that reads captured variable `v`.
*/
private predicate capturedCallRead(CallExprBase call, BasicBlock bb, int i, Variable v) {
private predicate capturedCallRead(Expr call, BasicBlock bb, int i, Variable v) {
isControlFlowJump(call) and
exists(Cfg::CfgScope scope |
hasCapturedRead(v, scope) and
(
Expand All @@ -295,7 +299,8 @@ private predicate capturedCallRead(CallExprBase call, BasicBlock bb, int i, Vari
* Holds if the call `call` at index `i` in basic block `bb` may reach a callable
* that writes captured variable `v`.
*/
predicate capturedCallWrite(CallExprBase call, BasicBlock bb, int i, Variable v) {
predicate capturedCallWrite(Expr call, BasicBlock bb, int i, Variable v) {
isControlFlowJump(call) and
call = bb.getNode(i).getAstNode() and
exists(Cfg::CfgScope scope |
hasVariableReadWithCapturedWrite(bb, any(int j | j > i), v, scope)
Expand Down
17 changes: 17 additions & 0 deletions rust/ql/lib/codeql/rust/elements/AsyncBlockExpr.qll
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/**
* This module provides the public class `AsyncBlockExpr`.
*/

private import codeql.rust.elements.BlockExpr

/**
* An async block expression. For example:
* ```rust
* async {
* let x = 42;
* }
* ```
*/
final class AsyncBlockExpr extends BlockExpr {
AsyncBlockExpr() { this.isAsync() }
}
12 changes: 12 additions & 0 deletions rust/ql/lib/codeql/rust/elements/internal/AstNodeImpl.qll
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/

private import codeql.rust.elements.internal.generated.AstNode
private import codeql.rust.controlflow.ControlFlowGraph

/**
* INTERNAL: This module contains the customizable definition of `AstNode` and should not
Expand Down Expand Up @@ -44,6 +45,17 @@ module Impl {
)
}

/** Gets the CFG scope that encloses this node, if any. */
cached
CfgScope getEnclosingCfgScope() {
exists(AstNode p | p = this.getParentNode() |
result = p
or
not p instanceof CfgScope and
result = p.getEnclosingCfgScope()
)
}

/** Holds if this node is inside a macro expansion. */
predicate isInMacroExpansion() {
this = any(MacroCall mc).getExpanded()
Expand Down
3 changes: 2 additions & 1 deletion rust/ql/lib/codeql/rust/elements/internal/VariableImpl.qll
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
private import rust
private import codeql.rust.controlflow.ControlFlowGraph
private import codeql.rust.elements.internal.generated.ParentChild
private import codeql.rust.elements.internal.PathExprBaseImpl::Impl as PathExprBaseImpl
private import codeql.rust.elements.internal.FormatTemplateVariableAccessImpl::Impl as FormatTemplateVariableAccessImpl
Expand Down Expand Up @@ -445,7 +446,7 @@ module Impl {
Variable getVariable() { result = v }

/** Holds if this access is a capture. */
predicate isCapture() { this.getEnclosingCallable() != v.getPat().getEnclosingCallable() }
predicate isCapture() { this.getEnclosingCfgScope() != v.getPat().getEnclosingCfgScope() }

override string toString() { result = name }

Expand Down
1 change: 1 addition & 0 deletions rust/ql/lib/rust.qll
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import codeql.Locations
import codeql.files.FileSystem
import codeql.rust.elements.AssignmentOperation
import codeql.rust.elements.LogicalOperation
import codeql.rust.elements.AsyncBlockExpr
import codeql.rust.elements.Variable
import codeql.rust.elements.NamedFormatArgument
import codeql.rust.elements.PositionalFormatArgument
Loading

0 comments on commit 2bb5603

Please sign in to comment.