feat: 0.7 on COURSEMO

This commit is contained in:
2024-04-28 20:07:30 +08:00
parent ded1032825
commit 2d95b112c5
2 changed files with 102 additions and 159 deletions

View File

@@ -506,12 +506,12 @@
},
{
"cell_type": "code",
"execution_count": 230,
"execution_count": 238,
"id": "d8dffd7d",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-28T07:57:09.790124Z",
"start_time": "2024-04-28T07:57:09.780591Z"
"end_time": "2024-04-28T08:00:54.037178Z",
"start_time": "2024-04-28T08:00:54.027410Z"
}
},
"outputs": [],
@@ -524,7 +524,6 @@
" self.mp = nn.AvgPool3d(2)\n",
" self.relu = nn.LeakyReLU()\n",
" self.fc1 = nn.Linear(3888, 6)\n",
" self.fc2 = nn.Linear(128, 6)\n",
" self.flatten = nn.Flatten()\n",
" def forward(self, x):\n",
" x = self.conv1(x)\n",
@@ -534,8 +533,6 @@
" # print(x.shape)\n",
" \n",
" x = x.view(-1, 3888)\n",
" x = self.fc1(x)\n",
" # x = self.fc2(x)\n",
" return x\n",
" \n",
"def train(model, criterion, optimizer, loader, epochs = 10):\n",
@@ -548,8 +545,7 @@
" optimizer.step()\n",
" print(f'Epoch {epoch}, Loss: {loss.item()}')\n",
" return model\n",
"def process_data(X, y):\n",
" y = np.array(y)\n",
"def process_X(X):\n",
" X = np.array([video[:6] for video in X])\n",
" tensor_videos = torch.tensor(X, dtype=torch.float32)\n",
" # Clip values to 0 and 255\n",
@@ -558,6 +554,11 @@
" for i in range(tensor_videos.shape[0]):\n",
" for j in range(tensor_videos.shape[1]):\n",
" tensor_videos[i][j][torch.isnan(tensor_videos[i][j])] = torch.mean(tensor_videos[i][j][~torch.isnan(tensor_videos[i][j])])\n",
" return tensor_videos\n",
" \n",
"def process_data(X, y):\n",
" y = np.array(y)\n",
" tensor_videos = process_X(X)\n",
" # Undersample the data for each of the 6 classes. Select max of 300 samples for each class\n",
" # Very much generated with the assitance of chatGPT with some modifications\n",
" # Get the indices of each class\n",
@@ -607,12 +608,12 @@
},
{
"cell_type": "code",
"execution_count": 217,
"execution_count": 239,
"id": "9245ab47",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-28T07:55:53.563103Z",
"start_time": "2024-04-28T07:55:53.544134Z"
"end_time": "2024-04-28T08:00:56.273946Z",
"start_time": "2024-04-28T08:00:56.253771Z"
}
},
"outputs": [],
@@ -640,17 +641,17 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0, Loss: 4.225716590881348\n",
"Epoch 1, Loss: 0.9198675155639648\n",
"Epoch 2, Loss: 1.7365752458572388\n",
"Epoch 3, Loss: 0.4570190906524658\n",
"Epoch 4, Loss: 0.11014104634523392\n",
"Epoch 5, Loss: 0.24420055747032166\n",
"Epoch 6, Loss: 0.03079795092344284\n",
"Epoch 7, Loss: 0.07790327817201614\n",
"Epoch 8, Loss: 0.07603466510772705\n",
"Epoch 9, Loss: 0.04154537618160248\n",
"F1 Score (macro): 0.51\n"
"Epoch 0, Loss: 85.83575439453125\n",
"Epoch 1, Loss: 43.13077926635742\n",
"Epoch 2, Loss: 13.879751205444336\n",
"Epoch 3, Loss: 3.084989070892334\n",
"Epoch 4, Loss: 5.557327747344971\n",
"Epoch 5, Loss: 3.1260528564453125\n",
"Epoch 6, Loss: 3.4430527687072754\n",
"Epoch 7, Loss: 5.166628837585449\n",
"Epoch 8, Loss: 4.4976654052734375\n",
"Epoch 9, Loss: 5.530020236968994\n",
"F1 Score (macro): 0.02\n"
]
}
],
@@ -666,12 +667,12 @@
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-04-28T07:57:38.644155Z",
"start_time": "2024-04-28T07:57:35.958882Z"
"end_time": "2024-04-28T08:01:04.071319Z",
"start_time": "2024-04-28T08:01:01.436939Z"
}
},
"id": "abb2d957f4a15bd2",
"execution_count": 235
"execution_count": 241
},
{
"cell_type": "code",