Skip to content

Commit

Permalink
Extend loop normalization (#705)
Browse files Browse the repository at this point in the history
Signed-off-by: Hernan Ponce de Leon <[email protected]>
Co-authored-by: Hernan Ponce de Leon <[email protected]>
Co-authored-by: Thomas Haas <[email protected]>
  • Loading branch information
3 people authored Aug 3, 2024
1 parent eb15f40 commit 6582d4a
Show file tree
Hide file tree
Showing 35 changed files with 657 additions and 22 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ export DAT3M_OUTPUT=$DAT3M_HOME/output
At least the following compiler flag needs to be set, further can be added (only to verify C programs)
```
export CFLAGS="-I$DAT3M_HOME/include"
export OPTFLAGS="-mem2reg -sroa -early-cse -indvars -loop-unroll -fix-irreducible -loop-simplify -simplifycfg -gvn"
```

If you are verifying C code, be sure `clang` is in your `PATH`.
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
21 changes: 21 additions & 0 deletions benchmarks/miscellaneous/jumpIntoLoop.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#include <stdint.h>
#include <assert.h>
#include <dat3m.h>

volatile int32_t x = 0;

int main()
{
int i = 0;
int jumpIntoLoop = __VERIFIER_nondet_bool();
if (jumpIntoLoop) goto L;

__VERIFIER_loop_bound(6);
for (i = 1; i < 5; i++) {
L:
x++;
}

assert ((jumpIntoLoop && x == 5) || (!jumpIntoLoop && x == 4));
return 0;
}
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
58 changes: 58 additions & 0 deletions benchmarks/miscellaneous/unsupported-loop-normalization.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#include <assert.h>
#include <dat3m.h>

int main()
{
unsigned int x = __VERIFIER_nondet_uint();
A:
if (x >= 1) {
x = 4;
goto B;
} else {
goto C;
}
B:
// 3, 4, 5, 6
if (x > 3) {
goto D;
} else {
goto E;
}
D:
// 3, 4, 5, 6
x++;

if (x > 5) {
goto Halt;
} else {
goto E;
}
E:
// 3, 4, 5
if (x < 4) {
goto D;
} else {
goto F;
}
F:
// 4, 5
if (x < 3) {
goto C;
} else {
goto G;
}
G:
// 0, 1, 2, 4, 5
x++;
goto C;
C:
// 0, 1, 2, 3, 5, 6
if (x > 2) {
goto B;
} else {
goto G;
}
Halt:
// 6, 7
assert(x == 6);
}
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,46 +1,97 @@
package com.dat3m.dartagnan.program.processing;

import com.dat3m.dartagnan.expression.Expression;
import com.dat3m.dartagnan.expression.ExpressionFactory;
import com.dat3m.dartagnan.expression.type.IntegerType;
import com.dat3m.dartagnan.expression.type.TypeFactory;
import com.dat3m.dartagnan.program.Function;
import com.dat3m.dartagnan.program.Register;
import com.dat3m.dartagnan.program.analysis.SyntacticContextAnalysis;
import com.dat3m.dartagnan.program.event.Event;
import com.dat3m.dartagnan.program.event.EventFactory;
import com.dat3m.dartagnan.program.event.core.CondJump;
import com.dat3m.dartagnan.program.event.core.Label;
import com.dat3m.dartagnan.program.event.core.Local;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/*
This pass transforms loops to have a single backjump using forward jumping.
Given a loop of the form
This pass normalizes loops to have a single unconditional backjump and a single entry point.
It achieves this via two transformations.
L:
...
if X goto L
...
if Y goto L
More Code
(1) Given a loop of the form
entry:
...
goto C
...
L:
...
C:
...
goto L
it transforms it to
L:
...
if X goto __repeatLoop_L
...
if Y goto __repeatLoop_L
goto __breakLoop_L
__repeatLoop_L:
goto L
__breakLoop_L
More Code
...
entry:
__loopEntryPoint_L <- 0
...
__loopEntryPoint_L <- 1
goto L
...
L:
__forwardTo_L <- __loopEntryPoint_L
__loopEntryPoint_L <- 0
if __forwardTo_L == 1 goto C
...
C:
...
goto L
(2) Given a loop of the form
L:
...
if X goto L
...
if Y goto L
More Code
it transforms it to
L:
...
if X goto __repeatLoop_L
...
if Y goto __repeatLoop_L
goto __breakLoop_L
__repeatLoop_L:
goto L
__breakLoop_L
More Code
...
*/
public class NormalizeLoops implements FunctionProcessor {

private final TypeFactory types = TypeFactory.getInstance();
private final ExpressionFactory expressions = ExpressionFactory.getInstance();

public static NormalizeLoops newInstance() {
return new NormalizeLoops();
}

@Override
public void run(Function function) {

guaranteeSingleEntry(function);
IdReassignment.newInstance().run(function);
guaranteeSingleUnconditionalBackjump(function);

}

private void guaranteeSingleUnconditionalBackjump(Function function) {
int counter = 0;
for (Label label : function.getEvents(Label.class)) {
final List<CondJump> backJumps = label.getJumpSet().stream()
Expand Down Expand Up @@ -70,4 +121,83 @@ public void run(Function function) {
counter++;
}
}

private void guaranteeSingleEntry(Function function) {
int loopCounter = 0;
for (Label loopBegin : function.getEvents(Label.class)) {
final List<CondJump> backJumps = loopBegin.getJumpSet().stream()
.filter(j -> j.getLocalId() > loopBegin.getLocalId())
.sorted()
.toList();
if (backJumps.isEmpty()) {
continue;
}

final CondJump loopEnd = backJumps.get(backJumps.size() - 1);
final List<Label> loopIrregularEntryPoints = loopBegin.getSuccessor().getSuccessors().stream()
.takeWhile(ev -> ev != loopEnd)
.filter(Label.class::isInstance).map(Label.class::cast)
.filter(l -> isEntryPoint(loopBegin, loopEnd, l))
.toList();

if (loopIrregularEntryPoints.isEmpty()) {
continue;
}

final IntegerType entryPointType = types.getByteType();
final Register entryPointReg = function.newRegister(String.format("__loopEntryPoint_%s#%s", loopBegin, loopCounter), entryPointType);
final Register forwarderReg = function.newRegister(String.format("__forwardTo_%s#%s", loopBegin, loopCounter), entryPointType);
final Local initEntryPointReg = EventFactory.newLocal(entryPointReg, expressions.makeZero(entryPointType));
final Local assignForwarderReg = EventFactory.newLocal(forwarderReg, entryPointReg);
final Local resetEntryPointReg = EventFactory.newLocal(entryPointReg, expressions.makeZero(entryPointType));
function.getEntry().insertAfter(initEntryPointReg);

final List<Event> forwardingInstrumentation = new ArrayList<>();
forwardingInstrumentation.add(assignForwarderReg);
forwardingInstrumentation.add(resetEntryPointReg);

int counter = 0;
for (Label entryPoint : loopIrregularEntryPoints) {
final List<CondJump> enteringJumps = getEnteringJumps(loopBegin,loopEnd, entryPoint);
assert (!enteringJumps.isEmpty());

final Expression entryPointValue = expressions.makeValue(++counter, entryPointType);
for (CondJump enteringJump : enteringJumps) {
if (enteringJump.getLocalId() > loopEnd.getLocalId()) {
// TODO: This case is rare as it would imply we have two (or more) overlapping loops.
// In this case, we should first merge the overlapping loops into one large loop.
final String error = String.format("Cannot normalize loop with loop-entering backjump (overlapping loops?): %d:%s \t %s",
enteringJump.getLocalId(), enteringJump, SyntacticContextAnalysis.getSourceLocationString(enteringJump));
throw new UnsupportedOperationException(error);
}
if (!enteringJump.isGoto()) {
// TODO: We should support this case, but the current implementation is wrong
// because if an instrumented jump is not taken, it still updates the entry point reg
// which will never get reset: we would end up accidentally forwarding a regular loop entry.
final String error = String.format("Cannot normalize loop with conditional loop-entering jump: %d:%s \t %s",
enteringJump.getLocalId(), enteringJump, SyntacticContextAnalysis.getSourceLocationString(enteringJump));
throw new UnsupportedOperationException(error);
}
enteringJump.getPredecessor().insertAfter(EventFactory.newLocal(entryPointReg, entryPointValue));
enteringJump.updateReferences(Map.of(entryPoint, loopBegin));
}

final CondJump forwardingJump = EventFactory.newJump(expressions.makeEQ(forwarderReg, entryPointValue), entryPoint);
forwardingInstrumentation.add(forwardingJump);
}

loopBegin.insertAfter(forwardingInstrumentation);
loopCounter++;
}
}

private boolean isEntryPoint(Event beginning, Event end, Label internal) {
return !getEnteringJumps(beginning, end, internal).isEmpty();
}

private List<CondJump> getEnteringJumps(Event beginning, Event end, Label internal) {
return internal.getJumpSet().stream()
.filter(j -> j.getLocalId() < beginning.getLocalId() || end.getLocalId() < j.getLocalId())
.toList();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ private ProcessingManager(Configuration config) throws InvalidConfigurationExcep
final FunctionProcessor removeDeadJumps = RemoveDeadCondJumps.fromConfig(config);
programProcessors.addAll(Arrays.asList(
printBeforeProcessing ? DebugPrint.withHeader("Before processing", Printer.Mode.ALL) : null,
ProgramProcessor.fromFunctionProcessor(NormalizeLoops.newInstance(), Target.FUNCTIONS, true),
intrinsics.markIntrinsicsPass(),
GEPToAddition.newInstance(),
NaiveDevirtualisation.newInstance(),
Expand All @@ -99,6 +98,7 @@ private ProcessingManager(Configuration config) throws InvalidConfigurationExcep
Simplifier.fromConfig(config)
), Target.FUNCTIONS, true
),
ProgramProcessor.fromFunctionProcessor(NormalizeLoops.newInstance(), Target.FUNCTIONS, true),
RegisterDecomposition.newInstance(),
RemoveDeadFunctions.newInstance(),
printAfterSimplification ? DebugPrint.withHeader("After simplification", Printer.Mode.ALL) : null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ public static Iterable<Object[]> data() throws IOException {
{"staticLoops", IMM, PASS, 1},
{"offsetof", IMM, PASS, 1},
{"ctlz", IMM, PASS, 1},
{"cttz", IMM, PASS, 1}
{"cttz", IMM, PASS, 1},
{"jumpIntoLoop", IMM, PASS, 11}
});
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
package com.dat3m.dartagnan.exceptions;

import com.dat3m.dartagnan.exception.MalformedProgramException;
import java.lang.UnsupportedOperationException;
import com.dat3m.dartagnan.expression.Expression;
import com.dat3m.dartagnan.expression.ExpressionFactory;
import com.dat3m.dartagnan.expression.type.TypeFactory;
import com.dat3m.dartagnan.parsers.program.ProgramParser;
import com.dat3m.dartagnan.parsers.program.utils.ProgramBuilder;
import com.dat3m.dartagnan.program.Function;
import com.dat3m.dartagnan.program.Program;
import com.dat3m.dartagnan.program.Program.SourceLanguage;
import com.dat3m.dartagnan.program.Thread;
import com.dat3m.dartagnan.program.analysis.BranchEquivalence;
import com.dat3m.dartagnan.program.event.EventFactory;
import com.dat3m.dartagnan.program.event.core.Skip;
import com.dat3m.dartagnan.program.processing.NormalizeLoops;

import org.junit.Test;
import org.sosy_lab.common.configuration.Configuration;

Expand Down Expand Up @@ -83,4 +87,11 @@ public void LocationNotInitialized() throws Exception {
public void RegisterNotInitialized() throws Exception {
new ProgramParser().parse(new File(getTestResourcePath("exceptions/RegisterNotInitialized.litmus")));
}

@Test(expected = UnsupportedOperationException.class)
public void UnsupportedLoopNormalization() throws Exception {
Program p = new ProgramParser().parse(new File(getTestResourcePath("exceptions/unsupported-loop-normalization.ll")));
Function main = p.getFunctionByName("main").get();
NormalizeLoops.newInstance().run(main);
}
}
Loading

0 comments on commit 6582d4a

Please sign in to comment.