feat: update structure
This commit is contained in:
37
cs2030s/labs/Lab9/Lab9a/Lab9a.java
Normal file
37
cs2030s/labs/Lab9/Lab9a/Lab9a.java
Normal file
@@ -0,0 +1,37 @@
|
||||
import java.util.Scanner;
|
||||
import java.time.Instant;
|
||||
import java.time.Duration;
|
||||
|
||||
/**
|
||||
* Lab9a is the main driver class for testing matrix multiplication.
|
||||
* Usage: java Lab9a
|
||||
*/
|
||||
class Lab9a {
|
||||
public static void main(String[] args) {
|
||||
Scanner scanner = new Scanner(System.in);
|
||||
int n = scanner.nextInt();
|
||||
|
||||
// Read matrix 1
|
||||
Matrix m1 = new Matrix(n);
|
||||
for (int i = 0; i < n; i++) {
|
||||
for (int j = 0; j < n; j++) {
|
||||
m1.m[i][j] = scanner.nextDouble();
|
||||
}
|
||||
}
|
||||
|
||||
// Read matrix 1
|
||||
Matrix m2 = new Matrix(n);
|
||||
for (int i = 0; i < n; i++) {
|
||||
for (int j = 0; j < n; j++) {
|
||||
m2.m[i][j] = scanner.nextDouble();
|
||||
}
|
||||
}
|
||||
|
||||
// Multiply matrices
|
||||
long startTime = System.currentTimeMillis();
|
||||
Matrix res = Matrix.parallelMultiply(m1, m2);
|
||||
long endTime = System.currentTimeMillis();
|
||||
// System.out.println("Taken: " + (endTime - startTime));
|
||||
System.out.println(res);
|
||||
}
|
||||
}
|
||||
BIN
cs2030s/labs/Lab9/Lab9a/Lab9a.pdf
Normal file
BIN
cs2030s/labs/Lab9/Lab9a/Lab9a.pdf
Normal file
Binary file not shown.
223
cs2030s/labs/Lab9/Lab9a/Matrix.java
Normal file
223
cs2030s/labs/Lab9/Lab9a/Matrix.java
Normal file
@@ -0,0 +1,223 @@
|
||||
import java.util.function.Supplier;
|
||||
import java.lang.StringBuilder;
|
||||
import java.util.concurrent.ForkJoinPool;
|
||||
import java.util.concurrent.RecursiveTask;
|
||||
import java.util.concurrent.ForkJoinTask;
|
||||
import java.lang.Runtime;
|
||||
|
||||
/**
|
||||
* Encapsulate a square matrix of double values.
|
||||
*/
|
||||
class Matrix {
|
||||
/**
|
||||
* 2D square array of double values, storing the matrix.
|
||||
*/
|
||||
double[][] m;
|
||||
/**
|
||||
* The number of columns and rows in the matrix.
|
||||
*/
|
||||
int dimension;
|
||||
|
||||
private static final int THRESHOLD = 2;
|
||||
|
||||
/**
|
||||
* Checks if two matrices are equals.
|
||||
*
|
||||
* @param m1 First matrices to check
|
||||
* @param m2 Second matrices to check against
|
||||
* @return true if every elements in m1 and m2 are the same; false otherwise.
|
||||
*/
|
||||
public static boolean equals(Matrix m1, Matrix m2) {
|
||||
if (m1.dimension != m2.dimension) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < m1.dimension; i++) {
|
||||
for (int j = 0; j < m1.dimension; j++) {
|
||||
if (Math.abs(m1.m[i][j] - m2.m[i][j]) > 0.000001) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* A constructor for the matrix.
|
||||
*
|
||||
* @param dimension The number of rows.
|
||||
*/
|
||||
Matrix(int dimension) {
|
||||
this.dimension = dimension;
|
||||
this.m = new double[dimension][dimension];
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a matrix of d x d according to the given supplier.
|
||||
*
|
||||
* @param dimension The dimension of the matrix
|
||||
* @param supplier The lambda to generate the matrix with.
|
||||
* @return The new matrix.
|
||||
*/
|
||||
static Matrix generate(int dimension, Supplier<Double> supplier) {
|
||||
Matrix matrix = new Matrix(dimension);
|
||||
for (int row = 0; row < dimension; row++) {
|
||||
for (int col = 0; col < dimension; col++) {
|
||||
matrix.m[row][col] = supplier.get();
|
||||
}
|
||||
}
|
||||
return matrix;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a string representation of the matrix, pretty-printed
|
||||
* with each row on a single line.
|
||||
*
|
||||
* @return The string representation of this matrix.
|
||||
*/
|
||||
public String toString() {
|
||||
StringBuilder s = new StringBuilder();
|
||||
for (int row = 0; row < dimension; row++) {
|
||||
for (int col = 0; col < dimension; col++) {
|
||||
s.append(String.format("%.3f", m[row][col]) + " ");
|
||||
}
|
||||
s.append("\n");
|
||||
}
|
||||
return s.toString();
|
||||
}
|
||||
|
||||
/**
|
||||
* Multiply matrix m with this matrix, return a new result matrix.
|
||||
*
|
||||
* @param m1 The matrix to multiply with.
|
||||
* @param m2 The matrix to multiply with.
|
||||
* @param m1Row The starting row of m1.
|
||||
* @param m1Col The starting col of m1.
|
||||
* @param m2Row The starting row of m2.
|
||||
* @param m2Col The starting col of m2.
|
||||
* @param dimension The dimension of the input (sub)-matrices and the size
|
||||
* of the output matrix.
|
||||
* @return The new matrix.
|
||||
*/
|
||||
public static Matrix nonRecursiveMultiply(Matrix m1, Matrix m2,
|
||||
int m1Row, int m1Col, int m2Row, int m2Col, int dimension) {
|
||||
Matrix result = new Matrix(dimension);
|
||||
for (int row = 0; row < dimension; row++) {
|
||||
for (int col = 0; col < dimension; col++) {
|
||||
double sum = 0;
|
||||
// multiply row to col
|
||||
for (int i = 0; i < dimension; i++) {
|
||||
sum += m1.m[row + m1Row][i + m1Col] * m2.m[i + m2Row][col + m2Col];
|
||||
}
|
||||
result.m[row][col] = sum;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Multiple two matrices non-recursively.
|
||||
*
|
||||
* @param m1 The matrix to multiply with.
|
||||
* @param m2 The matrix to multiply with.
|
||||
* @return The resulting matrix m1 * m2
|
||||
*/
|
||||
public static Matrix nonRecursiveMultiply(Matrix m1, Matrix m2) {
|
||||
return Matrix.nonRecursiveMultiply(m1, m2, 0, 0, 0, 0, m1.dimension);
|
||||
}
|
||||
|
||||
/**
|
||||
* Multiply matrix m with this matrix, return a new result matrix.
|
||||
*
|
||||
* @param m1 The matrix to multiply with.
|
||||
* @param m2 The matrix to multiply with.
|
||||
* @param m1Row The starting row of m1.
|
||||
* @param m1Col The starting col of m1.
|
||||
* @param m2Row The starting row of m2.
|
||||
* @param m2Col The starting col of m2.
|
||||
* @param dimension The dimension of the input (sub)-matrices and the size
|
||||
* of the output matrix.
|
||||
* @return The resulting matrix m1 * m2
|
||||
*/
|
||||
public static Matrix recursiveMultiply(Matrix m1, Matrix m2,
|
||||
int m1Row, int m1Col, int m2Row, int m2Col, int dimension) {
|
||||
|
||||
// If the matrix is small enough, just multiple non-recursively.
|
||||
if (dimension <= THRESHOLD) {
|
||||
return Matrix.nonRecursiveMultiply(m1, m2, m1Row, m1Col, m2Row, m2Col, dimension);
|
||||
}
|
||||
|
||||
// Else, cut the matrix into four blocks of equal size, recursively
|
||||
// multiply then sum the multiplication result.
|
||||
int size = dimension / 2;
|
||||
Matrix result = new Matrix(dimension);
|
||||
Matrix a11b11 = recursiveMultiply(m1, m2, m1Row, m1Col, m2Row, m2Col, size);
|
||||
Matrix a12b21 = recursiveMultiply(m1, m2, m1Row, m1Col + size, m2Row + size, m2Col, size);
|
||||
for (int i = 0; i < size; i++) {
|
||||
double[] m1m = a11b11.m[i];
|
||||
double[] m2m = a12b21.m[i];
|
||||
double[] r1m = result.m[i];
|
||||
for (int j = 0; j < size; j++) {
|
||||
r1m[j] = m1m[j] + m2m[j];
|
||||
}
|
||||
}
|
||||
|
||||
Matrix a11b12 = recursiveMultiply(m1, m2, m1Row, m1Col, m2Row, m2Col + size, size);
|
||||
Matrix a12b22 = recursiveMultiply(m1, m2, m1Row, m1Col + size, m2Row + size, m2Col + size, size);
|
||||
for (int i = 0; i < size; i++) {
|
||||
double[] m1m = a11b12.m[i];
|
||||
double[] m2m = a12b22.m[i];
|
||||
double[] r1m = result.m[i];
|
||||
for (int j = 0; j < size; j++) {
|
||||
r1m[j + size] = m1m[j] + m2m[j];
|
||||
}
|
||||
}
|
||||
|
||||
Matrix a21b11 = recursiveMultiply(m1, m2, m1Row + size, m1Col, m2Row, m2Col, size);
|
||||
Matrix a22b21 = recursiveMultiply(m1, m2, m1Row + size, m1Col + size, m2Row + size, m2Col, size);
|
||||
for (int i = 0; i < size; i++) {
|
||||
double[] m1m = a21b11.m[i];
|
||||
double[] m2m = a22b21.m[i];
|
||||
double[] r1m = result.m[i + size];
|
||||
for (int j = 0; j < size; j++) {
|
||||
r1m[j] = m1m[j] + m2m[j];
|
||||
}
|
||||
}
|
||||
|
||||
Matrix a21b12 = recursiveMultiply(m1, m2, m1Row + size, m1Col, m2Row, m2Col + size, size);
|
||||
Matrix a22b22 = recursiveMultiply(m1, m2, m1Row + size, m1Col + size, m2Row + size, m2Col + size, size);
|
||||
for (int i = 0; i < size; i++) {
|
||||
double[] m1m = a21b12.m[i];
|
||||
double[] m2m = a22b22.m[i];
|
||||
double[] r1m = result.m[i + size];
|
||||
for (int j = 0; j < size; j++) {
|
||||
r1m[j + size] = m1m[j] + m2m[j];
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Multiple two matrices recursively but sequentially with
|
||||
* divide-and-conquer algorithm.
|
||||
*
|
||||
* @param m1 The matrix to multiply with.
|
||||
* @param m2 The matrix to multiply with.
|
||||
* @return The resulting matrix m1 * m2
|
||||
*/
|
||||
public static Matrix recursiveMultiply(Matrix m1, Matrix m2) {
|
||||
return Matrix.recursiveMultiply(m1, m2, 0, 0, 0, 0, m1.dimension);
|
||||
}
|
||||
|
||||
/**
|
||||
* Multiple two matrices recursively and parallely with
|
||||
* divide-and-conquer algorithm.
|
||||
*
|
||||
* @param m1 The matrix to multiply with.
|
||||
* @param m2 The matrix to multiply with.
|
||||
* @return The resulting matrix m1 * m2
|
||||
*/
|
||||
public static Matrix parallelMultiply(Matrix m1, Matrix m2) {
|
||||
return new MatrixMultiplication(m1, m2, 0, 0, 0, 0, m1.dimension)
|
||||
.compute();
|
||||
}
|
||||
}
|
||||
133
cs2030s/labs/Lab9/Lab9a/MatrixMultiplication.java
Normal file
133
cs2030s/labs/Lab9/Lab9a/MatrixMultiplication.java
Normal file
@@ -0,0 +1,133 @@
|
||||
import java.util.concurrent.ForkJoinTask;
|
||||
import java.util.concurrent.RecursiveTask;
|
||||
|
||||
class MatrixMultiplication extends RecursiveTask<Matrix> {
|
||||
|
||||
/** The fork threshold. */
|
||||
private static final int FORK_THRESHOLD = 1024; // Find a good threshold
|
||||
|
||||
/** The first matrix to multiply with. */
|
||||
private final Matrix m1;
|
||||
|
||||
/** The second matrix to multiply with. */
|
||||
private final Matrix m2;
|
||||
|
||||
/** The starting row of m1. */
|
||||
private final int m1Row;
|
||||
|
||||
/** The starting col of m1. */
|
||||
private final int m1Col;
|
||||
|
||||
/** The starting row of m2. */
|
||||
private final int m2Row;
|
||||
|
||||
/** The starting col of m2. */
|
||||
private final int m2Col;
|
||||
|
||||
/**
|
||||
* The dimension of the input (sub)-matrices and the size of the output
|
||||
* matrix.
|
||||
*/
|
||||
private int dimension;
|
||||
|
||||
/**
|
||||
* A constructor for the Matrix Multiplication class.
|
||||
*
|
||||
* @param m1 The matrix to multiply with.
|
||||
* @param m2 The matrix to multiply with.
|
||||
* @param m1Row The starting row of m1.
|
||||
* @param m1Col The starting col of m1.
|
||||
* @param m2Row The starting row of m2.
|
||||
* @param m2Col The starting col of m2.
|
||||
* @param dimension The dimension of the input (sub)-matrices and the size
|
||||
* of the output matrix.
|
||||
*/
|
||||
MatrixMultiplication(Matrix m1, Matrix m2, int m1Row, int m1Col, int m2Row,
|
||||
int m2Col, int dimension) {
|
||||
this.m1 = m1;
|
||||
this.m2 = m2;
|
||||
this.m1Row = m1Row;
|
||||
this.m1Col = m1Col;
|
||||
this.m2Row = m2Row;
|
||||
this.m2Col = m2Col;
|
||||
this.dimension = dimension;
|
||||
}
|
||||
|
||||
public static final int THRESHOLD = 2;
|
||||
|
||||
@Override
|
||||
public Matrix compute() {
|
||||
// Modify this
|
||||
if (dimension <= THRESHOLD) {
|
||||
return Matrix.nonRecursiveMultiply(m1, m2, m1Row, m1Col, m2Row, m2Col, dimension);
|
||||
}
|
||||
|
||||
int size = dimension / 2;
|
||||
Matrix result = new Matrix(dimension);
|
||||
|
||||
ForkJoinTask<Matrix> a11b11Fork = new MatrixMultiplication(m1, m2, m1Row, m1Col, m2Row, m2Col, size).fork();
|
||||
ForkJoinTask<Matrix> a12b21Fork = new MatrixMultiplication(m1, m2, m1Row, m1Col + size, m2Row + size, m2Col, size)
|
||||
.fork();
|
||||
|
||||
ForkJoinTask<Matrix> a11b12Fork = new MatrixMultiplication(m1, m2, m1Row, m1Col, m2Row, m2Col + size, size).fork();
|
||||
ForkJoinTask<Matrix> a12b22Fork = new MatrixMultiplication(m1, m2, m1Row, m1Col + size, m2Row + size, m2Col + size,
|
||||
size).fork();
|
||||
|
||||
ForkJoinTask<Matrix> a21b11Fork = new MatrixMultiplication(m1, m2, m1Row + size, m1Col, m2Row, m2Col, size).fork();
|
||||
ForkJoinTask<Matrix> a22b21Fork = new MatrixMultiplication(m1, m2, m1Row + size, m1Col + size, m2Row + size, m2Col,
|
||||
size).fork();
|
||||
|
||||
ForkJoinTask<Matrix> a21b12Fork = new MatrixMultiplication(m1, m2, m1Row + size, m1Col, m2Row, m2Col + size, size)
|
||||
.fork();
|
||||
ForkJoinTask<Matrix> a22b22Fork = new MatrixMultiplication(m1, m2, m1Row + size, m1Col + size, m2Row + size,
|
||||
m2Col + size, size).fork();
|
||||
|
||||
Matrix a11b11 = a11b11Fork.join();
|
||||
Matrix a12b21 = a12b21Fork.join();
|
||||
for (int i = 0; i < size; i++) {
|
||||
double[] m1m = a11b11.m[i];
|
||||
double[] m2m = a12b21.m[i];
|
||||
double[] r1m = result.m[i];
|
||||
for (int j = 0; j < size; j++) {
|
||||
r1m[j] = m1m[j] + m2m[j];
|
||||
}
|
||||
}
|
||||
|
||||
Matrix a11b12 = a11b12Fork.join();
|
||||
Matrix a12b22 = a12b22Fork.join();
|
||||
for (int i = 0; i < size; i++) {
|
||||
double[] m1m = a11b12.m[i];
|
||||
double[] m2m = a12b22.m[i];
|
||||
double[] r1m = result.m[i];
|
||||
for (int j = 0; j < size; j++) {
|
||||
r1m[j + size] = m1m[j] + m2m[j];
|
||||
}
|
||||
}
|
||||
|
||||
Matrix a21b11 = a21b11Fork.join();
|
||||
Matrix a22b21 = a22b21Fork.join();
|
||||
|
||||
for (int i = 0; i < size; i++) {
|
||||
double[] m1m = a21b11.m[i];
|
||||
double[] m2m = a22b21.m[i];
|
||||
double[] r1m = result.m[i + size];
|
||||
for (int j = 0; j < size; j++) {
|
||||
r1m[j] = m1m[j] + m2m[j];
|
||||
}
|
||||
}
|
||||
|
||||
Matrix a21b12 = a21b12Fork.join();
|
||||
Matrix a22b22 = a22b22Fork.join();
|
||||
for (int i = 0; i < size; i++) {
|
||||
double[] m1m = a21b12.m[i];
|
||||
double[] m2m = a22b22.m[i];
|
||||
double[] r1m = result.m[i + size];
|
||||
for (int j = 0; j < size; j++) {
|
||||
r1m[j + size] = m1m[j] + m2m[j];
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
||||
1027
cs2030s/labs/Lab9/Lab9a/input/Test1.in
Normal file
1027
cs2030s/labs/Lab9/Lab9a/input/Test1.in
Normal file
File diff suppressed because one or more lines are too long
1027
cs2030s/labs/Lab9/Lab9a/input/Test2.in
Normal file
1027
cs2030s/labs/Lab9/Lab9a/input/Test2.in
Normal file
File diff suppressed because one or more lines are too long
1027
cs2030s/labs/Lab9/Lab9a/input/Test3.in
Normal file
1027
cs2030s/labs/Lab9/Lab9a/input/Test3.in
Normal file
File diff suppressed because one or more lines are too long
1027
cs2030s/labs/Lab9/Lab9a/input/Test4.in
Normal file
1027
cs2030s/labs/Lab9/Lab9a/input/Test4.in
Normal file
File diff suppressed because one or more lines are too long
1027
cs2030s/labs/Lab9/Lab9a/input/Test5.in
Normal file
1027
cs2030s/labs/Lab9/Lab9a/input/Test5.in
Normal file
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user