feat: update structure

This commit is contained in:
2024-01-22 14:27:40 +08:00
parent 7836c9185c
commit 3544a28a2e
559 changed files with 120846 additions and 4102 deletions

View File

@@ -0,0 +1,37 @@
import java.util.Scanner;
import java.time.Instant;
import java.time.Duration;
/**
* Lab9a is the main driver class for testing matrix multiplication.
* Usage: java Lab9a
*/
class Lab9a {
public static void main(String[] args) {
Scanner scanner = new Scanner(System.in);
int n = scanner.nextInt();
// Read matrix 1
Matrix m1 = new Matrix(n);
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
m1.m[i][j] = scanner.nextDouble();
}
}
// Read matrix 1
Matrix m2 = new Matrix(n);
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
m2.m[i][j] = scanner.nextDouble();
}
}
// Multiply matrices
long startTime = System.currentTimeMillis();
Matrix res = Matrix.parallelMultiply(m1, m2);
long endTime = System.currentTimeMillis();
// System.out.println("Taken: " + (endTime - startTime));
System.out.println(res);
}
}

Binary file not shown.

View File

@@ -0,0 +1,223 @@
import java.util.function.Supplier;
import java.lang.StringBuilder;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.ForkJoinTask;
import java.lang.Runtime;
/**
* Encapsulate a square matrix of double values.
*/
class Matrix {
/**
* 2D square array of double values, storing the matrix.
*/
double[][] m;
/**
* The number of columns and rows in the matrix.
*/
int dimension;
private static final int THRESHOLD = 2;
/**
* Checks if two matrices are equals.
*
* @param m1 First matrices to check
* @param m2 Second matrices to check against
* @return true if every elements in m1 and m2 are the same; false otherwise.
*/
public static boolean equals(Matrix m1, Matrix m2) {
if (m1.dimension != m2.dimension) {
return false;
}
for (int i = 0; i < m1.dimension; i++) {
for (int j = 0; j < m1.dimension; j++) {
if (Math.abs(m1.m[i][j] - m2.m[i][j]) > 0.000001) {
return false;
}
}
}
return true;
}
/**
* A constructor for the matrix.
*
* @param dimension The number of rows.
*/
Matrix(int dimension) {
this.dimension = dimension;
this.m = new double[dimension][dimension];
}
/**
* Generate a matrix of d x d according to the given supplier.
*
* @param dimension The dimension of the matrix
* @param supplier The lambda to generate the matrix with.
* @return The new matrix.
*/
static Matrix generate(int dimension, Supplier<Double> supplier) {
Matrix matrix = new Matrix(dimension);
for (int row = 0; row < dimension; row++) {
for (int col = 0; col < dimension; col++) {
matrix.m[row][col] = supplier.get();
}
}
return matrix;
}
/**
* Return a string representation of the matrix, pretty-printed
* with each row on a single line.
*
* @return The string representation of this matrix.
*/
public String toString() {
StringBuilder s = new StringBuilder();
for (int row = 0; row < dimension; row++) {
for (int col = 0; col < dimension; col++) {
s.append(String.format("%.3f", m[row][col]) + " ");
}
s.append("\n");
}
return s.toString();
}
/**
* Multiply matrix m with this matrix, return a new result matrix.
*
* @param m1 The matrix to multiply with.
* @param m2 The matrix to multiply with.
* @param m1Row The starting row of m1.
* @param m1Col The starting col of m1.
* @param m2Row The starting row of m2.
* @param m2Col The starting col of m2.
* @param dimension The dimension of the input (sub)-matrices and the size
* of the output matrix.
* @return The new matrix.
*/
public static Matrix nonRecursiveMultiply(Matrix m1, Matrix m2,
int m1Row, int m1Col, int m2Row, int m2Col, int dimension) {
Matrix result = new Matrix(dimension);
for (int row = 0; row < dimension; row++) {
for (int col = 0; col < dimension; col++) {
double sum = 0;
// multiply row to col
for (int i = 0; i < dimension; i++) {
sum += m1.m[row + m1Row][i + m1Col] * m2.m[i + m2Row][col + m2Col];
}
result.m[row][col] = sum;
}
}
return result;
}
/**
* Multiple two matrices non-recursively.
*
* @param m1 The matrix to multiply with.
* @param m2 The matrix to multiply with.
* @return The resulting matrix m1 * m2
*/
public static Matrix nonRecursiveMultiply(Matrix m1, Matrix m2) {
return Matrix.nonRecursiveMultiply(m1, m2, 0, 0, 0, 0, m1.dimension);
}
/**
* Multiply matrix m with this matrix, return a new result matrix.
*
* @param m1 The matrix to multiply with.
* @param m2 The matrix to multiply with.
* @param m1Row The starting row of m1.
* @param m1Col The starting col of m1.
* @param m2Row The starting row of m2.
* @param m2Col The starting col of m2.
* @param dimension The dimension of the input (sub)-matrices and the size
* of the output matrix.
* @return The resulting matrix m1 * m2
*/
public static Matrix recursiveMultiply(Matrix m1, Matrix m2,
int m1Row, int m1Col, int m2Row, int m2Col, int dimension) {
// If the matrix is small enough, just multiple non-recursively.
if (dimension <= THRESHOLD) {
return Matrix.nonRecursiveMultiply(m1, m2, m1Row, m1Col, m2Row, m2Col, dimension);
}
// Else, cut the matrix into four blocks of equal size, recursively
// multiply then sum the multiplication result.
int size = dimension / 2;
Matrix result = new Matrix(dimension);
Matrix a11b11 = recursiveMultiply(m1, m2, m1Row, m1Col, m2Row, m2Col, size);
Matrix a12b21 = recursiveMultiply(m1, m2, m1Row, m1Col + size, m2Row + size, m2Col, size);
for (int i = 0; i < size; i++) {
double[] m1m = a11b11.m[i];
double[] m2m = a12b21.m[i];
double[] r1m = result.m[i];
for (int j = 0; j < size; j++) {
r1m[j] = m1m[j] + m2m[j];
}
}
Matrix a11b12 = recursiveMultiply(m1, m2, m1Row, m1Col, m2Row, m2Col + size, size);
Matrix a12b22 = recursiveMultiply(m1, m2, m1Row, m1Col + size, m2Row + size, m2Col + size, size);
for (int i = 0; i < size; i++) {
double[] m1m = a11b12.m[i];
double[] m2m = a12b22.m[i];
double[] r1m = result.m[i];
for (int j = 0; j < size; j++) {
r1m[j + size] = m1m[j] + m2m[j];
}
}
Matrix a21b11 = recursiveMultiply(m1, m2, m1Row + size, m1Col, m2Row, m2Col, size);
Matrix a22b21 = recursiveMultiply(m1, m2, m1Row + size, m1Col + size, m2Row + size, m2Col, size);
for (int i = 0; i < size; i++) {
double[] m1m = a21b11.m[i];
double[] m2m = a22b21.m[i];
double[] r1m = result.m[i + size];
for (int j = 0; j < size; j++) {
r1m[j] = m1m[j] + m2m[j];
}
}
Matrix a21b12 = recursiveMultiply(m1, m2, m1Row + size, m1Col, m2Row, m2Col + size, size);
Matrix a22b22 = recursiveMultiply(m1, m2, m1Row + size, m1Col + size, m2Row + size, m2Col + size, size);
for (int i = 0; i < size; i++) {
double[] m1m = a21b12.m[i];
double[] m2m = a22b22.m[i];
double[] r1m = result.m[i + size];
for (int j = 0; j < size; j++) {
r1m[j + size] = m1m[j] + m2m[j];
}
}
return result;
}
/**
* Multiple two matrices recursively but sequentially with
* divide-and-conquer algorithm.
*
* @param m1 The matrix to multiply with.
* @param m2 The matrix to multiply with.
* @return The resulting matrix m1 * m2
*/
public static Matrix recursiveMultiply(Matrix m1, Matrix m2) {
return Matrix.recursiveMultiply(m1, m2, 0, 0, 0, 0, m1.dimension);
}
/**
* Multiple two matrices recursively and parallely with
* divide-and-conquer algorithm.
*
* @param m1 The matrix to multiply with.
* @param m2 The matrix to multiply with.
* @return The resulting matrix m1 * m2
*/
public static Matrix parallelMultiply(Matrix m1, Matrix m2) {
return new MatrixMultiplication(m1, m2, 0, 0, 0, 0, m1.dimension)
.compute();
}
}

View File

@@ -0,0 +1,133 @@
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;
class MatrixMultiplication extends RecursiveTask<Matrix> {
/** The fork threshold. */
private static final int FORK_THRESHOLD = 1024; // Find a good threshold
/** The first matrix to multiply with. */
private final Matrix m1;
/** The second matrix to multiply with. */
private final Matrix m2;
/** The starting row of m1. */
private final int m1Row;
/** The starting col of m1. */
private final int m1Col;
/** The starting row of m2. */
private final int m2Row;
/** The starting col of m2. */
private final int m2Col;
/**
* The dimension of the input (sub)-matrices and the size of the output
* matrix.
*/
private int dimension;
/**
* A constructor for the Matrix Multiplication class.
*
* @param m1 The matrix to multiply with.
* @param m2 The matrix to multiply with.
* @param m1Row The starting row of m1.
* @param m1Col The starting col of m1.
* @param m2Row The starting row of m2.
* @param m2Col The starting col of m2.
* @param dimension The dimension of the input (sub)-matrices and the size
* of the output matrix.
*/
MatrixMultiplication(Matrix m1, Matrix m2, int m1Row, int m1Col, int m2Row,
int m2Col, int dimension) {
this.m1 = m1;
this.m2 = m2;
this.m1Row = m1Row;
this.m1Col = m1Col;
this.m2Row = m2Row;
this.m2Col = m2Col;
this.dimension = dimension;
}
public static final int THRESHOLD = 2;
@Override
public Matrix compute() {
// Modify this
if (dimension <= THRESHOLD) {
return Matrix.nonRecursiveMultiply(m1, m2, m1Row, m1Col, m2Row, m2Col, dimension);
}
int size = dimension / 2;
Matrix result = new Matrix(dimension);
ForkJoinTask<Matrix> a11b11Fork = new MatrixMultiplication(m1, m2, m1Row, m1Col, m2Row, m2Col, size).fork();
ForkJoinTask<Matrix> a12b21Fork = new MatrixMultiplication(m1, m2, m1Row, m1Col + size, m2Row + size, m2Col, size)
.fork();
ForkJoinTask<Matrix> a11b12Fork = new MatrixMultiplication(m1, m2, m1Row, m1Col, m2Row, m2Col + size, size).fork();
ForkJoinTask<Matrix> a12b22Fork = new MatrixMultiplication(m1, m2, m1Row, m1Col + size, m2Row + size, m2Col + size,
size).fork();
ForkJoinTask<Matrix> a21b11Fork = new MatrixMultiplication(m1, m2, m1Row + size, m1Col, m2Row, m2Col, size).fork();
ForkJoinTask<Matrix> a22b21Fork = new MatrixMultiplication(m1, m2, m1Row + size, m1Col + size, m2Row + size, m2Col,
size).fork();
ForkJoinTask<Matrix> a21b12Fork = new MatrixMultiplication(m1, m2, m1Row + size, m1Col, m2Row, m2Col + size, size)
.fork();
ForkJoinTask<Matrix> a22b22Fork = new MatrixMultiplication(m1, m2, m1Row + size, m1Col + size, m2Row + size,
m2Col + size, size).fork();
Matrix a11b11 = a11b11Fork.join();
Matrix a12b21 = a12b21Fork.join();
for (int i = 0; i < size; i++) {
double[] m1m = a11b11.m[i];
double[] m2m = a12b21.m[i];
double[] r1m = result.m[i];
for (int j = 0; j < size; j++) {
r1m[j] = m1m[j] + m2m[j];
}
}
Matrix a11b12 = a11b12Fork.join();
Matrix a12b22 = a12b22Fork.join();
for (int i = 0; i < size; i++) {
double[] m1m = a11b12.m[i];
double[] m2m = a12b22.m[i];
double[] r1m = result.m[i];
for (int j = 0; j < size; j++) {
r1m[j + size] = m1m[j] + m2m[j];
}
}
Matrix a21b11 = a21b11Fork.join();
Matrix a22b21 = a22b21Fork.join();
for (int i = 0; i < size; i++) {
double[] m1m = a21b11.m[i];
double[] m2m = a22b21.m[i];
double[] r1m = result.m[i + size];
for (int j = 0; j < size; j++) {
r1m[j] = m1m[j] + m2m[j];
}
}
Matrix a21b12 = a21b12Fork.join();
Matrix a22b22 = a22b22Fork.join();
for (int i = 0; i < size; i++) {
double[] m1m = a21b12.m[i];
double[] m2m = a22b22.m[i];
double[] r1m = result.m[i + size];
for (int j = 0; j < size; j++) {
r1m[j + size] = m1m[j] + m2m[j];
}
}
return result;
}
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long