feat: 0.39 on local

This commit is contained in:
2024-04-28 17:54:48 +08:00
parent 9702f0e3eb
commit 4983c0be68

View File

@@ -23,37 +23,29 @@ class CNN3D(nn.Module):
super(CNN3D, self).__init__() super(CNN3D, self).__init__()
self.conv1 = nn.Conv3d(1, 16, 2, 1, 2) self.conv1 = nn.Conv3d(1, 16, 2, 1, 2)
self.batchnorm3d = nn.BatchNorm3d(16) self.batchnorm3d = nn.BatchNorm3d(16)
self.batchnorm1d = nn.BatchNorm1d(64)
self.dropout = nn.Dropout(0.5) self.dropout = nn.Dropout(0.5)
self.mp3d = nn.AvgPool3d(2) self.mp3d = nn.AvgPool3d(2)
self.relu = nn.ReLU() self.relu = nn.LeakyReLU()
self.lstm = nn.LSTM(5184, 64, 1, batch_first=True) self.lstm = nn.LSTM(5184, 64, 1, batch_first=True)
self.fc2 = nn.Linear(64, 6) self.fc2 = nn.Linear(64, 6)
def forward(self, x): def forward(self, x):
x = self.conv1(x) x = self.conv1(x)
x = self.mp3d(x)
x = self.batchnorm3d(x)
x = self.relu(x) x = self.relu(x)
x = self.batchnorm3d(x)
x = self.mp3d(x)
x = self.dropout(x) x = self.dropout(x)
x = x.view(-1, 5184) x = x.view(-1, 5184)
# print(x.shape)
x, _ = self.lstm(x) x, _ = self.lstm(x)
# print(x.shape)
x = self.batchnorm1d(x)
x = self.relu(x) x = self.relu(x)
x = self.dropout(x)
x = self.fc2(x) x = self.fc2(x)
return torch.softmax(x, dim=1) return torch.softmax(x, dim=1)
def train(model, criterion, optimizer, loader, epochs=10):
def train(model, criterion, optimizer, loader, epochs=20):
for epoch in range(epochs): for epoch in range(epochs):
for idx, (inputs, labels) in enumerate(loader): for idx, (inputs, labels) in enumerate(loader):
optimizer.zero_grad() optimizer.zero_grad()
@@ -82,7 +74,7 @@ def process_data(X, y):
# Get the indices of each class # Get the indices of each class
indices = [np.argwhere(y == i).squeeze(1) for i in range(6)] indices = [np.argwhere(y == i).squeeze(1) for i in range(6)]
# Get the number of samples to take for each class # Get the number of samples to take for each class
num_samples_to_take = 300 num_samples_to_take = 1500
# Get the indices of the samples to take # Get the indices of the samples to take
indices_to_take = [np.random.choice(indices[i], num_samples_to_take, replace=True) for i in range(6)] indices_to_take = [np.random.choice(indices[i], num_samples_to_take, replace=True) for i in range(6)]
# Concatenate the indices # Concatenate the indices
@@ -102,8 +94,8 @@ class Model():
def fit(self, X, y): def fit(self, X, y):
X, y = process_data(X, y) X, y = process_data(X, y)
train_dataset = torch.utils.data.TensorDataset(X, y) train_dataset = torch.utils.data.TensorDataset(X, y)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
train(self.model, self.criterion, self.optimizer, train_loader, 10) train(self.model, self.criterion, self.optimizer, train_loader, 5)
def predict(self, X): def predict(self, X):
self.model.eval() self.model.eval()