对拍程序-java


大部分OJ除了给出测试用例,其他用例一般都不会给出,拿不到自己WA的用例,就比较难分析出自己的程序问题是出在那里。对拍程序能较大程度的解决这个问题。

对拍,就是两份程序同时读入相同的数据,读入的数据都是我们用随机数生成器生成符合要求的数据,然后分别输出,看两份程序的输出是否一致。

对拍具体分为两种,一种是比赛中的对拍,此时的对拍,一份程序是逻辑正确的暴力程序,另一份是时间、内存合理的程序,但不确定逻辑是否正确。这时可以将暴力程序和符合要求的程序对拍,若对拍的程序多次输出一致,则大概率没有问题。
另一种对拍是比较常用的平常练习时的对拍,某道题提交后,OJ会告知提交WA了,但是不给出WA的用例,于是我们就比较难定位程序中的问题,这时对拍就派上用场了,一般的OJ都可以看到他人AC的代码,我们将AC的代码与自己的代码对拍,当两份程序对同一份输入产生的输出不同时,那么我们就找到了一个导致自己程序WA的用例。

以下是java实现的一份对拍程序,用来生成数据并比较两个算法程序的输出是否一致。

CompareCode类是一个对拍程序,用来对拍A类与B类的同样的输入,是否输出一致。随机数生成用的是generateTestCases方法,不同的题目需要改变此方法,使生成的数据格式符合题目要求即可。其他A类或B类放入AC的程序,另一份放入自己的程序。若A类是AC的程序,则对于生成的用例testcase1.txt,文件outputA1.txt是AC程序的输出文件。outputB1.txt是自己程序的输出文件。
此对拍程序通过反射实现,改变输入流(生成的测试用例)、输出流到文件(程序的产生的输出),若输出不一致,则会在控制台用警告的方式提示,哪些随机生成的用例的输出不一致。

基于jdk11的readString方法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import util.RandomArrList; // 自定义的随机数生成器

import java.io.*;
import java.lang.reflect.Method;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;

/**
* 对拍程序
*/
public class CompareCode {

private static final String TEST_CASES_FOLDER = "./"; // 当前目录

public static void main(String[] args) throws IOException {
// 要对拍多少次。即测试用例文件生成的个数
int numCase = 1;

// 生成测试用例数据并写入文件
generateTestCases(numCase);

// 调用算法 A 和算法 B 进行对拍
for (int i = 1; i <= numCase; i++) {
String inputFilePath = TEST_CASES_FOLDER + "testcase" + i + ".txt";
String outputAFilePath = TEST_CASES_FOLDER + "outputA" + i + ".txt";
String outputBFilePath = TEST_CASES_FOLDER + "outputB" + i + ".txt";

runAlgorithm(A.class, inputFilePath, outputAFilePath);
runAlgorithm(B.class, inputFilePath, outputBFilePath);

if (!compareOutputs(outputAFilePath, outputBFilePath)) {
System.setOut(System.err);
System.out.println("Output mismatch for testcase" + i);
System.out.println("A: " + outputAFilePath);
System.out.println("B: " + outputBFilePath);
System.setOut(System.out);
System.out.println("============================================");
System.out.println();
}
}
}

// 调用程序处理输入文件并将输出结果写入文件
private static void runAlgorithm(Class<?> clazz, String inputFilePath, String outputFilePath) {
try {
Method method = clazz.getMethod("main", String[].class);

// 读取输入文件内容
String inputFileContent = Files.readString(Paths.get(inputFilePath));

// 设置 System.in 为读取输入文件内容的流
ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(inputFileContent.getBytes());
System.setIn(byteArrayInputStream);

// 设置 System.out 为输出到文件
PrintStream fileOut = new PrintStream(new FileOutputStream(outputFilePath));
System.setOut(fileOut);

// 调用算法的 main 方法,并将空的参数数组传递给算法
method.invoke(null, new Object[]{new String[]{}});

// 恢复 System.in 和 System.out
System.setIn(System.in);
System.setOut(System.out);
} catch (Exception e) {
e.printStackTrace();
}
}

// 比较两个输出文件的内容是否一致
private static boolean compareOutputs(String outputFilePath1, String outputFilePath2) throws IOException {
String content1 = Files.readString(Paths.get(outputFilePath1));
String content2 = Files.readString(Paths.get(outputFilePath2));

return content1.equals(content2);
}


// 生成测试用例数据并写入文件
// 测试用例格式1:一行一个整数n,n从3到10,下一行n个数,大小在1到n范围,用空格隔开
private static void generateTestCases1(int numTestCases) throws IOException {
for (int i = 1; i <= numTestCases; i++) {
String filePath = TEST_CASES_FOLDER + "testcase" + i + ".txt";
try (BufferedWriter writer = new BufferedWriter(new FileWriter(filePath))) {
int n = RandomArrList.getRandNum(3, 10);
writer.write(n + "\n");
int[] randomArr = RandomArrList.getRandomArr(n, 1, n);
for (int j = 0; j < randomArr.length; j++) {
writer.write(randomArr[j] + " ");
}
writer.write("\n");
}
}
}

// 生成测试用例数据并写入文件
// 测试用例格式:一行一个整数t,表示有t组数据,接下来t组,每组第一行一个整数n、m,属于1到15.。。。。
private static void generateTestCases(int numTestCases) throws IOException {

for (int i = 1; i <= numTestCases; i++) {
String filePath = TEST_CASES_FOLDER + "testcase" + i + ".txt";
try (BufferedWriter writer = new BufferedWriter(new FileWriter(filePath))) {
int t = RandomArrList.getRandNum(1, 10);
writer.write(t + "\n");
for (int j = 0; j < t; j++) {
int n = RandomArrList.getRandNum(1, 15);
int m = RandomArrList.getRandNum(1, 15);
writer.write(n + " " + m);
writer.write("\n");
ArrayList<Integer> randomList = RandomArrList.getRandomList(n, 1, 10);
for (Integer num : randomList) {
writer.write(num + " ");
}
writer.write("\n");
}
}
}
}
}

jdk8版本代码,替换掉Files.readString()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import util.RandomArrList;

import java.io.*;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

public class CompareCode {

private static final String TEST_CASES_FOLDER = "./"; // 当前目录

public static void main(String[] args) throws IOException {
int numCase = 1;

// 生成测试用例数据并写入文件
generateTestCases(numCase);

// 调用算法 A 和算法 B 进行对拍
for (int i = 1; i <= numCase; i++) {
String inputFilePath = TEST_CASES_FOLDER + "testcase" + i + ".txt";
String outputAFilePath = TEST_CASES_FOLDER + "outputA" + i + ".txt";
String outputBFilePath = TEST_CASES_FOLDER + "outputB" + i + ".txt";

runAlgorithm(A.class, inputFilePath, outputAFilePath);
runAlgorithm(B.class, inputFilePath, outputBFilePath);

if (!compareOutputs(outputAFilePath, outputBFilePath)) {
System.setOut(System.err);
System.out.println("Output mismatch for testcase" + i);
System.out.println("A: " + outputAFilePath);
System.out.println("B: " + outputBFilePath);
System.setOut(System.out);
System.out.println("============================================");
System.out.println();
}
}
}

// 调用程序处理输入文件并将输出结果写入文件
private static void runAlgorithm(Class<?> clazz, String inputFilePath, String outputFilePath) {
try {
Method method = clazz.getMethod("main", String[].class);

// 读取输入文件内容
List<String> inputFileContent = Files.readAllLines(Paths.get(inputFilePath));

// 设置 System.in 为读取输入文件内容的流
ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(String.join(System.lineSeparator(), inputFileContent).getBytes(StandardCharsets.UTF_8));
System.setIn(byteArrayInputStream);

// 设置 System.out 为输出到文件
PrintStream fileOut = new PrintStream(new FileOutputStream(outputFilePath));
System.setOut(fileOut);

// 调用算法的 main 方法,并将空的参数数组传递给算法
method.invoke(null, new Object[]{new String[]{}});

// 恢复 System.in 和 System.out
System.setIn(System.in);
System.setOut(System.out);
} catch (Exception e) {
e.printStackTrace();
}
}

// 比较两个输出文件的内容是否一致
private static boolean compareOutputs(String outputFilePath1, String outputFilePath2) throws IOException {
byte[] bytes1 = Files.readAllBytes(Paths.get(outputFilePath1));
byte[] bytes2 = Files.readAllBytes(Paths.get(outputFilePath2));
String content1 = new String(bytes1, StandardCharsets.UTF_8);
String content2 = new String(bytes2, StandardCharsets.UTF_8);

return content1.equals(content2);
}


// 生成测试用例数据并写入文件
private static void generateTestCases1(int numTestCases) throws IOException {
for (int i = 1; i <= numTestCases; i++) {
String filePath = TEST_CASES_FOLDER + "testcase" + i + ".txt";
try (BufferedWriter writer = new BufferedWriter(new FileWriter(filePath))) {
int n = RandomArrList.getRandNum(3, 10);
writer.write(n + "\n");
int[] randomArr = RandomArrList.getRandomArr(n, 1, n);
for (int j = 0; j < randomArr.length; j++) {
writer.write(randomArr[j] + " ");
}
writer.write("\n");
}
}
}

// 生成测试用例数据并写入文件
private static void generateTestCases(int numTestCases) throws IOException {

for (int i = 1; i <= numTestCases; i++) {
String filePath = TEST_CASES_FOLDER + "testcase" + i + ".txt";
try (BufferedWriter writer = new BufferedWriter(new FileWriter(filePath))) {
int t = RandomArrList.getRandNum(1, 10);
writer.write(t + "\n");
for (int j = 0; j < t; j++) {
int n = RandomArrList.getRandNum(1, 15);
int m = RandomArrList.getRandNum(1, 15);
writer.write(n + " " + m);
writer.write("\n");
ArrayList<Integer> randomList = RandomArrList.getRandomList(n, 1, 10);
for (Integer num : randomList) {
writer.write(num + " ");
}
writer.write("\n");
}
}
}
}
}

RandomArrList是我自己写的随机数生成器,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;

/**
* 生成随机数组
*
**/
public class RandomArrList {

// 生成随机数数组,长度为length,最小数为start,最大数为end
public static int[] getRandomArr(int length, int start, int end) {
// T[] ts = (T[])Array.newInstance(componentType, length);
int[] resArr = new int[length];
Random random = new Random();
for (int i = 0; i < length; i++) {
int num = random.nextInt(end - start + 1) + start;
resArr[i] = num;
}
return resArr;
}

// 生成随机数集合,长度为length,最小数为start,最大数为end
public static ArrayList<Integer> getRandomList(int length, int start, int end) {
ArrayList<Integer> resList = new ArrayList<>(length);
Random random = new Random();
for (int i = 0; i < length; i++) {
int num = random.nextInt(end - start + 1) + start;
resList.add(num);
}
return resList;
}

// 生成单个随机数
public static int getRandNum(int start, int end) {
Random random = new Random();
return random.nextInt(end - start + 1) + start;
}
}