feat: idek anymore

This commit is contained in:
2024-04-29 12:45:46 +08:00
parent 215cde2d19
commit d294ac0e38
3 changed files with 417 additions and 98 deletions

View File

@@ -315,12 +315,12 @@
},
{
"cell_type": "code",
"execution_count": 72,
"execution_count": 10,
"id": "a44b7aa4",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-28T12:00:17.228662Z",
"start_time": "2024-04-28T12:00:17.209494Z"
"end_time": "2024-04-28T12:27:25.926991Z",
"start_time": "2024-04-28T12:27:25.917322Z"
}
},
"outputs": [],
@@ -406,8 +406,8 @@
" def fit(self, X, y):\n",
" X, y = process_data(X, y)\n",
" train_dataset = torch.utils.data.TensorDataset(X, y)\n",
" train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)\n",
" train(self.model, self.criterion, self.optimizer, train_loader, 10)\n",
" train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)\n",
" train(self.model, self.criterion, self.optimizer, train_loader, 20)\n",
"\n",
" def predict(self, X):\n",
" self.model.eval()\n",
@@ -438,12 +438,12 @@
},
{
"cell_type": "code",
"execution_count": 73,
"execution_count": 2,
"id": "4f4dd489",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-28T12:00:19.363096Z",
"start_time": "2024-04-28T12:00:19.352424Z"
"end_time": "2024-04-28T12:09:46.115322Z",
"start_time": "2024-04-28T12:09:45.631452Z"
}
},
"outputs": [],
@@ -458,12 +458,12 @@
},
{
"cell_type": "code",
"execution_count": 74,
"execution_count": 3,
"id": "3064e0ff",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-28T12:00:20.265060Z",
"start_time": "2024-04-28T12:00:20.234748Z"
"end_time": "2024-04-28T12:09:47.340881Z",
"start_time": "2024-04-28T12:09:47.317719Z"
}
},
"outputs": [],
@@ -477,12 +477,12 @@
},
{
"cell_type": "code",
"execution_count": 75,
"execution_count": 12,
"id": "27c9fd10",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-28T12:00:37.185569Z",
"start_time": "2024-04-28T12:00:22.239036Z"
"end_time": "2024-04-28T12:28:29.269402Z",
"start_time": "2024-04-28T12:28:02.494602Z"
}
},
"outputs": [
@@ -490,19 +490,29 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0, Loss: 0.7495917081832886\n",
"Epoch 1, Loss: 0.42713749408721924\n",
"Epoch 2, Loss: 0.21424821019172668\n",
"Epoch 3, Loss: 0.02086367830634117\n",
"Epoch 4, Loss: 0.005386564414948225\n",
"Epoch 5, Loss: 0.00319607718847692\n",
"Epoch 6, Loss: 0.007663913071155548\n",
"Epoch 7, Loss: 0.003004509722813964\n",
"Epoch 8, Loss: 0.0044013322331011295\n",
"Epoch 9, Loss: 0.0016760551370680332\n",
"F1 Score (macro): 0.75\n",
"CPU times: user 57.8 s, sys: 1min 12s, total: 2min 10s\n",
"Wall time: 14.9 s\n"
"Epoch 0, Loss: 0.5610745549201965\n",
"Epoch 1, Loss: 0.22023160755634308\n",
"Epoch 2, Loss: 0.03679683431982994\n",
"Epoch 3, Loss: 0.009054183959960938\n",
"Epoch 4, Loss: 0.0021134500857442617\n",
"Epoch 5, Loss: 0.002705463906750083\n",
"Epoch 6, Loss: 0.0045105633325874805\n",
"Epoch 7, Loss: 0.001958428416401148\n",
"Epoch 8, Loss: 0.0010891605634242296\n",
"Epoch 9, Loss: 0.0010821395553648472\n",
"Epoch 10, Loss: 0.0007317279814742506\n",
"Epoch 11, Loss: 0.0006673489115200937\n",
"Epoch 12, Loss: 0.00047141974209807813\n",
"Epoch 13, Loss: 0.00024128056247718632\n",
"Epoch 14, Loss: 0.0003150832490064204\n",
"Epoch 15, Loss: 0.0004005862574558705\n",
"Epoch 16, Loss: 0.00024190203112084419\n",
"Epoch 17, Loss: 0.0004451812419574708\n",
"Epoch 18, Loss: 0.000376795680494979\n",
"Epoch 19, Loss: 0.0003616203321143985\n",
"F1 Score (macro): 0.65\n",
"CPU times: user 2min 33s, sys: 255 ms, total: 2min 34s\n",
"Wall time: 26.8 s\n"
]
}
],