feat: idek anymore

Yadunand Prem 2024-04-29 12:45:46 +08:00
parent 215cde2d19
commit d294ac0e38
3 changed files with 417 additions and 98 deletions

View File

@@ -16,6 +16,7 @@ from torch import nn
import numpy as np
import torch
import os
from torchvision.transforms.functional import equalize
class CNN3D(nn.Module):
def __init__(self, hidden_size=32, dropout=0.0):
@@ -27,7 +28,7 @@ class CNN3D(nn.Module):
self.maxpool = nn.MaxPool3d(kernel_size=2, stride=2)
self.fc1 = nn.Linear(hidden_size*32, 256) # Calculate input size based on output from conv3
self.fc2 = nn.Linear(256, 6)
self.dropout = nn.Dropout(dropout)
# self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.conv1(x)
@@ -37,7 +38,7 @@ class CNN3D(nn.Module):
x = self.conv2(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.dropout(x)
# x = self.dropout(x)
x = x.view(x.size(0), -1) # Flatten features for fully connected layers
x = self.fc1(x)
@@ -56,17 +57,16 @@ def train(model, criterion, optimizer, loader, epochs=5):
print(f'Epoch {epoch}, Loss: {loss.item()}')
return model
class Model():
def __init__(self, batch_size=8,lr=0.001,epochs=10, dropout=0.0, hidden_size=32):
def __init__(self, batch_size=64,lr=0.001,epochs=5, dropout=0.0, hidden_size=32, n_samples=900):
print(batch_size, epochs, lr, dropout, hidden_size, n_samples)
self.batch_size = batch_size
self.lr = lr
self.epochs = epochs
self.model = CNN3D(dropout=dropout, hidden_size=hidden_size)
self.criterion = nn.CrossEntropyLoss()
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
self.n_samples = n_samples
def fit(self, X, y):
X, y = self.process_data(X, y)
@@ -81,31 +81,25 @@ class Model():
tensor_videos = torch.tensor(X, dtype=torch.float32)
# Clip values to 0 and 255
tensor_videos = np.clip(tensor_videos, 0, 255)
# TEMP
threshold = 180
tensor_videos[tensor_videos > threshold] = 255
tensor_videos[tensor_videos < threshold] = 0
# END TEMP
# Replace NaNs in each frame, with the average of the frame. This was generated with GPT
for i in range(tensor_videos.shape[0]):
for j in range(tensor_videos.shape[1]):
tensor_videos[i][j][torch.isnan(tensor_videos[i][j])] = torch.mean(
tensor_videos[i][j][~torch.isnan(tensor_videos[i][j])])
X = torch.Tensor(tensor_videos.unsqueeze(1))
result = self.model(X)
# tensor_videos = torch.Tensor(tensor_videos).to(torch.uint8).reshape(-1, 1, 16, 16)
# tensor_videos = equalize(tensor_videos).float().reshape(-1, 1, 6, 16, 16)
tensor_videos = torch.Tensor(tensor_videos).reshape(-1, 1, 6, 16, 16)
# some funky code to make the features more prominent
result = self.model(tensor_videos)
return torch.max(result, dim=1)[1].numpy()
def process_data(self, X, y, n_samples=600):
def process_data(self, X, y):
y = np.array(y)
X = np.array([video[:6] for video in X])
tensor_videos = torch.tensor(X, dtype=torch.float32)
# Clip values to 0 and 255
tensor_videos = np.clip(tensor_videos, 0, 255)
# TEMP
threshold = 180
tensor_videos[tensor_videos > threshold] = 255
tensor_videos[tensor_videos < threshold] = 0
# END TEMP
# Replace NaNs in each frame, with the average of the frame. This was generated with GPT
for i in range(tensor_videos.shape[0]):
@@ -118,13 +112,19 @@ class Model():
indices = [np.argwhere(y == i).squeeze(1) for i in range(6)]
# Get the number of samples to take for each class
# Get the indices of the samples to take
indices_to_take = [np.random.choice(indices[i], n_samples, replace=True) for i in range(6)]
indices_to_take = [np.random.choice(indices[i], self.n_samples, replace=True) for i in range(6)]
# Concatenate the indices
indices_to_take = np.concatenate(indices_to_take)
# Select the samples
tensor_videos = tensor_videos[indices_to_take].unsqueeze(1)
tensor_videos = tensor_videos[indices_to_take]
tensor_videos = torch.Tensor(tensor_videos).reshape(-1, 1, 6, 16, 16)
# Reshape the tensor to int for image processing
# tensor_videos = torch.Tensor(tensor_videos).to(torch.uint8).reshape(-1, 1, 16, 16)
# tensor_videos = equalize(tensor_videos).float().reshape(-1, 1, 6, 16, 16)
y = y[indices_to_take]
return torch.Tensor(tensor_videos), torch.Tensor(y).long()
return tensor_videos, torch.Tensor(y).long()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)
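
For readers skimming the diff, here is a minimal standalone sketch of the preprocessing path that process_data and predict now share: clip pixel values, binarise them around the hard-coded threshold of 180, fill NaNs per frame with that frame's mean, oversample every class to self.n_samples examples, and reshape to (N, 1, 6, 16, 16). The helper name preprocess_videos is ours, and the 6-class / 16x16-frame assumptions simply mirror the constants visible above; treat it as an illustration, not the exact source.

import numpy as np
import torch

def preprocess_videos(X, y, n_samples=900, threshold=180, n_classes=6):
    # Illustrative re-implementation of the process_data logic above (not the exact source).
    y = np.array(y)
    X = np.array([video[:6] for video in X])            # keep only the first 6 frames
    videos = torch.tensor(X, dtype=torch.float32)
    videos = torch.clamp(videos, 0, 255)                 # clip values to 0..255
    videos[videos > threshold] = 255                      # hard threshold (the "TEMP" block above)
    videos[videos <= threshold] = 0                       # NaNs fail both comparisons and survive to the fill step
    # Replace NaNs in each frame with the mean of that frame's non-NaN pixels.
    for i in range(videos.shape[0]):
        for j in range(videos.shape[1]):
            frame = videos[i][j]
            frame[torch.isnan(frame)] = torch.mean(frame[~torch.isnan(frame)])
    # Oversample each class to n_samples examples, sampling indices with replacement.
    per_class = [np.argwhere(y == c).squeeze(1) for c in range(n_classes)]
    take = np.concatenate([np.random.choice(per_class[c], n_samples, replace=True)
                           for c in range(n_classes)])
    videos = videos[take].reshape(-1, 1, 6, 16, 16)       # (N, channels=1, depth=6, 16, 16)
    return videos, torch.tensor(y[take]).long()
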

View File

@@ -315,12 +315,12 @@
},
{
"cell_type": "code",
"execution_count": 72,
"execution_count": 10,
"id": "a44b7aa4",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-28T12:00:17.228662Z",
"start_time": "2024-04-28T12:00:17.209494Z"
"end_time": "2024-04-28T12:27:25.926991Z",
"start_time": "2024-04-28T12:27:25.917322Z"
}
},
"outputs": [],
@@ -406,8 +406,8 @@
" def fit(self, X, y):\n",
" X, y = process_data(X, y)\n",
" train_dataset = torch.utils.data.TensorDataset(X, y)\n",
" train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)\n",
" train(self.model, self.criterion, self.optimizer, train_loader, 10)\n",
" train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)\n",
" train(self.model, self.criterion, self.optimizer, train_loader, 20)\n",
"\n",
" def predict(self, X):\n",
" self.model.eval()\n",
@@ -438,12 +438,12 @@
},
{
"cell_type": "code",
"execution_count": 73,
"execution_count": 2,
"id": "4f4dd489",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-28T12:00:19.363096Z",
"start_time": "2024-04-28T12:00:19.352424Z"
"end_time": "2024-04-28T12:09:46.115322Z",
"start_time": "2024-04-28T12:09:45.631452Z"
}
},
"outputs": [],
@@ -458,12 +458,12 @@
},
{
"cell_type": "code",
"execution_count": 74,
"execution_count": 3,
"id": "3064e0ff",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-28T12:00:20.265060Z",
"start_time": "2024-04-28T12:00:20.234748Z"
"end_time": "2024-04-28T12:09:47.340881Z",
"start_time": "2024-04-28T12:09:47.317719Z"
}
},
"outputs": [],
@@ -477,12 +477,12 @@
},
{
"cell_type": "code",
"execution_count": 75,
"execution_count": 12,
"id": "27c9fd10",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-28T12:00:37.185569Z",
"start_time": "2024-04-28T12:00:22.239036Z"
"end_time": "2024-04-28T12:28:29.269402Z",
"start_time": "2024-04-28T12:28:02.494602Z"
}
},
"outputs": [
@@ -490,19 +490,29 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0, Loss: 0.7495917081832886\n",
"Epoch 1, Loss: 0.42713749408721924\n",
"Epoch 2, Loss: 0.21424821019172668\n",
"Epoch 3, Loss: 0.02086367830634117\n",
"Epoch 4, Loss: 0.005386564414948225\n",
"Epoch 5, Loss: 0.00319607718847692\n",
"Epoch 6, Loss: 0.007663913071155548\n",
"Epoch 7, Loss: 0.003004509722813964\n",
"Epoch 8, Loss: 0.0044013322331011295\n",
"Epoch 9, Loss: 0.0016760551370680332\n",
"F1 Score (macro): 0.75\n",
"CPU times: user 57.8 s, sys: 1min 12s, total: 2min 10s\n",
"Wall time: 14.9 s\n"
"Epoch 0, Loss: 0.5610745549201965\n",
"Epoch 1, Loss: 0.22023160755634308\n",
"Epoch 2, Loss: 0.03679683431982994\n",
"Epoch 3, Loss: 0.009054183959960938\n",
"Epoch 4, Loss: 0.0021134500857442617\n",
"Epoch 5, Loss: 0.002705463906750083\n",
"Epoch 6, Loss: 0.0045105633325874805\n",
"Epoch 7, Loss: 0.001958428416401148\n",
"Epoch 8, Loss: 0.0010891605634242296\n",
"Epoch 9, Loss: 0.0010821395553648472\n",
"Epoch 10, Loss: 0.0007317279814742506\n",
"Epoch 11, Loss: 0.0006673489115200937\n",
"Epoch 12, Loss: 0.00047141974209807813\n",
"Epoch 13, Loss: 0.00024128056247718632\n",
"Epoch 14, Loss: 0.0003150832490064204\n",
"Epoch 15, Loss: 0.0004005862574558705\n",
"Epoch 16, Loss: 0.00024190203112084419\n",
"Epoch 17, Loss: 0.0004451812419574708\n",
"Epoch 18, Loss: 0.000376795680494979\n",
"Epoch 19, Loss: 0.0003616203321143985\n",
"F1 Score (macro): 0.65\n",
"CPU times: user 2min 33s, sys: 255 ms, total: 2min 34s\n",
"Wall time: 26.8 s\n"
]
}
],

File diff suppressed because one or more lines are too long
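
To reproduce the numbers in the notebook output above, the train/evaluate cell boils down to something like the sketch below, using the settings this commit introduces (batch_size=64 and 20 epochs in the notebook cell, n_samples=900 in the module). X and y are assumed to be loaded by earlier notebook cells, and the macro-F1 call is standard scikit-learn; this is an outline rather than the notebook's exact code.

from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

# Hold out 10% of the data, as in the module code above.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)

model = Model(batch_size=64, lr=0.001, epochs=20, dropout=0.0,
              hidden_size=32, n_samples=900)
model.fit(X_train, y_train)        # builds a DataLoader and runs the train() loop above
y_pred = model.predict(X_test)     # preprocesses X_test, then takes the argmax of the logits
print(f"F1 Score (macro): {f1_score(y_test, y_pred, average='macro'):.2f}")
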