maybe good?

This commit is contained in:
2024-04-28 15:58:30 +08:00
parent 1312e694c3
commit d2e87aec97
21 changed files with 3097 additions and 700 deletions

View File

@@ -23,21 +23,26 @@
"execution_count": 1,
"id": "adfd1c67",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-11T04:02:54.316718Z",
"start_time": "2024-04-11T04:02:53.604913Z"
},
"collapsed": false,
"jupyter": {
"outputs_hidden": false
"ExecuteTime": {
"end_time": "2024-04-27T17:03:32.690848Z",
"start_time": "2024-04-27T17:03:30.945851Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_95268/3863716717.py:17: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'\n",
" device = torch.device(\"mps\") if torch.has_mps else torch.device(\"cpu\")\n",
"/tmp/ipykernel_95268/3863716717.py:18: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'\n",
" torch.has_mps\n"
]
},
{
"data": {
"text/plain": [
"True"
]
"text/plain": "False"
},
"execution_count": 1,
"metadata": {},
@@ -139,22 +144,19 @@
"execution_count": null,
"id": "a63a577557da6e87",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
"collapsed": false
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 2,
"id": "7873ce10",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-03T08:51:30.782923Z",
"start_time": "2024-04-03T08:51:30.778415Z"
"end_time": "2024-04-27T17:03:40.731264Z",
"start_time": "2024-04-27T17:03:40.726572Z"
}
},
"outputs": [],
@@ -417,12 +419,12 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "c78ee5e1",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-11T04:03:03.689966Z",
"start_time": "2024-04-11T04:03:03.667782Z"
"end_time": "2024-04-27T17:04:36.032494Z",
"start_time": "2024-04-27T17:04:35.981931Z"
}
},
"outputs": [],
@@ -477,12 +479,12 @@
},
{
"cell_type": "code",
"execution_count": 86,
"execution_count": 5,
"id": "4e0fdf18",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-07T03:15:42.540657Z",
"start_time": "2024-04-07T03:15:42.436629Z"
"end_time": "2024-04-27T17:04:39.560907Z",
"start_time": "2024-04-27T17:04:39.424591Z"
}
},
"outputs": [
@@ -496,10 +498,8 @@
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAI1UlEQVR4nO3cr2uVfx/H8evczm8YTJSDwWEzajE5jCZ1oCzIBH+AMpBFwSKiCHNBsBks/gVOsGwWwxBElrZkGEuiIgxN0yXx3OG+71e6w3lfut+PRz4vrmt65pNP8NPp9Xq9BgCapvnXVr8AANuHKAAQogBAiAIAIQoAhCgAEKIAQIgCADHQ7wc7nc5GvgcAG6yf/6vspABAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEANb/QLsHVevXm21O3ny5F9+k/9vYmKivDlw4EB58/nz5/KmaZpmenq6vHn+/Hl58+vXr/KG3cNJAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACA6vV6v19cHO52Nfhd2kOHh4fLm06dPrZ7V51d0S7T5vdjMn+fRo0flzcOHD//+i7At9PPdc1IAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQAiIGtfgF2pjaXus3MzLR61tzcXHmzsrLS6lmb4dq1a612t27dKm+OHTvW6lnsXU4KAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCANHp9XmzWafT2eh3gT1hZGSk1e7du3flzY8fP8qbgwcPljfsDP38c++kAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAEAMbPULwE7W5kbRe/futXpWm5uKr1+/3upZ7F1OCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgDhQjz4r+Hh4fLmyZMn5c25c+fKm6Zpmm/fvpU3CwsLrZ7F3uWkAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABAuxGPb63a75c2ZM2fKm/v375c3x48fL2/aXGzXNE1z6dKl8mZ1dbXVs9i7nBQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAwoV4bJrbt2+32o2NjZU3p0+fbvWszTA0NNRqNzo6Wt4sLS2VN2tra+UNu4eTAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgDR6fV6vb4+2Ols9LuwRQYHB8ublZWV8ubIkSPlTdM0TZ9f0S3R5vdiM3+excXF8ubOnTvlzdu3b8sbNl8/3z0nBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAY2OoXYOudOnWqvDl06FB50/YiuOXl5fLm2bNn5c3Xr1/Lm800Pj5e3oyOjpY3U1NT5c3Y2Fh58/379/KGjeekAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCdXp+3lHU6nY1+F3aQiYmJ8mb//v2tnvXixYvyxmVr//Hx48fy5ujRo+VNm8v6Xr58Wd7wZ/r5595JAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBciAe72KtXr8qbCxculDdv3rwpb86ePVve8GdciAdAiSgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAMbDVLwBsnKWlpfLm4sWL5c3Q0FB5w/bkpABAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAuCUVdohut1veTE5Olje9Xq+8mZ2dLW/YnpwUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAMKFeLDJTpw40Wo3MzNT3hw+fLi8WV9fL2/m5+fLG7YnJwUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAcCEe/IEHDx6UN5OTk62e1eZyuzYuX75c3iwsLGzAm7AVnBQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAwoV4RU+fPi1v/vnnn/Lm9evX5U3TNM3i4mJ5s7q6Wt50u93yZt++feVN0zTNkSNHypvx8fHyZmJiorwZGhoqb3q9XnnTNE2zvr5e3rS53G5ubq68YfdwUgAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFACITq/P27k6nc5Gv8uOcPfu3fJmamqqvGn7593msrX5+fnyZmRkpLwZHBwsb5qm/QVym6HN39Py8nKrZ928ebO8WVhYaPUsdqd+fpecFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIt6QWdbvd8ubx48flzY0bN8qbptl9N4o2zfb+md6/f1/eTE5OtnrWhw8fWu3gf9ySCkCJKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgDhQrxt6sqVK61258+fL2/Gx8dbPavq58+frXazs7Plze/fv8ub6enp8ubLly/lzdraWnkDf4ML8QAoEQUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgXIgHsEe4EA+AElEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAGKg3w/2er2NfA8AtgEnBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAOLfZHAnDT21VNIAAAAASUVORK5CYII=",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
"text/plain": "<Figure size 640x480 with 1 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAJbklEQVR4nO3cMWieVR/G4ZM2CDHYGCrUCCo2GAex0EEq6FLo1kExDkZLzKAg1CVrwaK4BjsKSqBQiJkK3TNkqLabmkUMVixCFEEFNRmC4XUQbj6ww/t/vuRtTK9rzs15oE1/nsEz1Ov1eg0AWmuH7vYHALB/iAIAIQoAhCgAEKIAQIgCACEKAIQoABDD/f7g0NDQXn4HAHusn/9X2U0BgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBi+2x/A7jp8+HB5c/Xq1fJmamqqvHn22WfLm9Za+/PPPzvtgDo3BQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQDCK6kHzPBw/Y/0yJEj5c2TTz5Z3oyMjJQ3rXklFQbJTQGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgPIh3wJw6daq8OXHixB58CfBf5KYAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEB7EO2Dee++98mZsbKy8WVtbK282NzfLG2Cw3BQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAwoN4+9Tp06c77Z5//vld/pI7u3TpUnmztbW1B18C7CY3BQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYDwIN4+NTk52Wl3+PDhXf6SO1tfXx/IOcBguSkAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEF5JpW1sbAxkA+x/bgoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIA4UG8fero0aMDO+uHH34YyIZ/vPDCC512s7Ozu/wld/bTTz+VN5988kl54+/Q/uSmAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABAexNunXnvttYGd9d133w3srIPm/Pnz5c3Fixc7nTXIRxKr3nzzzfJmcXGx01nvvvtupx39cVMAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQAiKFer9fr6weHhvb6W/gfX331Vafd008/Xd6cO3euvFleXi5vDqLffvutvHnggQf24Evu7PPPPy9vHnroofJmamqqvNnc3CxvWmttbGys047W+vnn3k0BgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBi+2x9wL+jyguQjjzzS6awuL09+8cUXnc7az0ZHR8ubpaWl8ubBBx8sb7755pvyprXW3nnnnfJmZWWlvJmYmChvurzG+thjj5U3rbU2Pz9f3ly6dKnTWfciNwUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGA8CDeADz66KPlzfj4eKezfv755/Km6wNt+9nc3Fx5c/bs2fJma2urvHn//ffLm9a6PW7XxY8//jiQTZffi9Zae+ONN8obD+L1z00BgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIDyIx4H00ksvDeSchYWF8ubTTz/dgy/ZPTMzM+XNM888swdfcmdLS0sDO+te5KYAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEB7EG4DV1dXy5uuvv+501uOPP17ePPfcc+XNzZs3y5uuHn744fLm+PHje/Al//brr78O5JyuDh2q/3ffK6+8Ut6MjIyUNysrK+VNa619+OGHnXb0x00BgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIDyINwA7OzvlzZdfftnprKeeeqq8uXbtWnkzPT1d3ly/fr28aa210dHR8mZ8fLzTWVXHjh0byDldLSwslDcvvvhiebO+vl7eXLhwobxprbW//vqr047+uCkAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEEO9Xq/X1w8ODe31t7ALfv/99/Lm/vvvL29u3LhR3pw/f768aa21tbW18uatt94qbz766KPyZnt7u7z54IMPypvWuv05zc3NlTcTExPlze3bt8ubkydPljetdfs7zj/6+efeTQGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgPIh3wLz66qvlzeXLl8ub4eHh8uaXX34pb1pr7YknnihvdnZ2ypsrV66UN9PT0+VNn79yd83GxkZ5c+bMmfJmfX29vOH/40E8AEpEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAIj6q2bsa8vLy+VNl8cOP/744/Lm6NGj5U1rrd2+fbu8mZ2dLW9u3bpV3ux333//fXkzPz9f3njc7uBwUwAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFACIoV6v1+vrBzs8msbBNTMzU94sLi52Ouu+++7rtBuELr8Xff7K/cvq6mp58/bbb5c33377bXnDf0M/f/fcFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQDCg3gMzOTkZKfdxYsXy5vXX3+901lVf/zxR3nz8ssvdzrrs88+K2+2t7c7ncXB5EE8AEpEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACC8kgpwj/BKKgAlogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQAz3+4O9Xm8vvwOAfcBNAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgPgbLvku86SCoJIAAAAASUVORK5CYII="
},
"metadata": {},
"output_type": "display_data"
@@ -508,7 +508,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Label: 3\n"
"Label: 6\n"
]
}
],
@@ -561,12 +561,12 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"id": "6dd33b95",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-11T04:03:51.459649Z",
"start_time": "2024-04-11T04:03:51.434499Z"
"end_time": "2024-04-27T17:04:45.792747Z",
"start_time": "2024-04-27T17:04:45.419030Z"
}
},
"outputs": [
@@ -808,43 +808,6 @@
"__Tip:__ Don't be worried if your model takes a while to train. Your mileage may also vary depending on your CPU. But if you would like to speed things up, you can consider making use of your device's GPU to parallelize the matrix computations."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "e95b10e89d6cfe43",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
" self.conv = nn.Sequential(\n",
" nn.Conv2d(3, 32, (3, 3)),\n",
" nn.MaxPool2d((2, 2)),\n",
" nn.LeakyReLU(0.1),\n",
" nn.Conv2d(32, 64, (3, 3)),\n",
" nn.MaxPool2d((2, 2)),\n",
" nn.LeakyReLU(0.1),\n",
" )\n",
"\n",
" self.fc = nn.Sequential(\n",
" nn.Linear(64, 256),\n",
" nn.LeakyReLU(0.1),\n",
" nn.Linear(256, 128),\n",
" nn.LeakyReLU(0.1),\n",
" nn.Linear(128, classes)\n",
" )\n",
" \n",
" def forward(self, x):\n",
" # YOUR CODE HERE\n",
" x = self.conv(x)\n",
" x = x.view(x.shape[0], 64, 6*6).mean(2) # GAP do not remove this line\n",
" x = self.fc(x)\n",
" return x\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
@@ -882,13 +845,13 @@
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"File \u001b[0;32m<timed exec>:33\u001b[0m\n",
"File \u001b[0;32m<timed exec>:21\u001b[0m, in \u001b[0;36mtrain_model\u001b[0;34m(loader, model, device)\u001b[0m\n",
"File \u001b[0;32m/opt/homebrew/anaconda3/envs/cs2109s-ay2223s1/lib/python3.9/site-packages/torch/_tensor.py:396\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 387\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 388\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 389\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 390\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 394\u001b[0m create_graph\u001b[38;5;241m=\u001b[39mcreate_graph,\n\u001b[1;32m 395\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs)\n\u001b[0;32m--> 396\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/opt/homebrew/anaconda3/envs/cs2109s-ay2223s1/lib/python3.9/site-packages/torch/autograd/__init__.py:173\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 168\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 170\u001b[0m \u001b[38;5;66;03m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[1;32m 171\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 172\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 173\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 174\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 175\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[0;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)",
"File \u001B[0;32m<timed exec>:33\u001B[0m\n",
"File \u001B[0;32m<timed exec>:21\u001B[0m, in \u001B[0;36mtrain_model\u001B[0;34m(loader, model, device)\u001B[0m\n",
"File \u001B[0;32m/opt/homebrew/anaconda3/envs/cs2109s-ay2223s1/lib/python3.9/site-packages/torch/_tensor.py:396\u001B[0m, in \u001B[0;36mTensor.backward\u001B[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001B[0m\n\u001B[1;32m 387\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m has_torch_function_unary(\u001B[38;5;28mself\u001B[39m):\n\u001B[1;32m 388\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m handle_torch_function(\n\u001B[1;32m 389\u001B[0m Tensor\u001B[38;5;241m.\u001B[39mbackward,\n\u001B[1;32m 390\u001B[0m (\u001B[38;5;28mself\u001B[39m,),\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 394\u001B[0m create_graph\u001B[38;5;241m=\u001B[39mcreate_graph,\n\u001B[1;32m 395\u001B[0m inputs\u001B[38;5;241m=\u001B[39minputs)\n\u001B[0;32m--> 396\u001B[0m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mautograd\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbackward\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgradient\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43minputs\u001B[49m\u001B[43m)\u001B[49m\n",
"File \u001B[0;32m/opt/homebrew/anaconda3/envs/cs2109s-ay2223s1/lib/python3.9/site-packages/torch/autograd/__init__.py:173\u001B[0m, in \u001B[0;36mbackward\u001B[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001B[0m\n\u001B[1;32m 168\u001B[0m retain_graph \u001B[38;5;241m=\u001B[39m create_graph\n\u001B[1;32m 170\u001B[0m \u001B[38;5;66;03m# The reason we repeat same the comment below is that\u001B[39;00m\n\u001B[1;32m 171\u001B[0m \u001B[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001B[39;00m\n\u001B[1;32m 172\u001B[0m \u001B[38;5;66;03m# calls in the traceback and some print out the last line\u001B[39;00m\n\u001B[0;32m--> 173\u001B[0m \u001B[43mVariable\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_execution_engine\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mrun_backward\u001B[49m\u001B[43m(\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001B[39;49;00m\n\u001B[1;32m 174\u001B[0m \u001B[43m \u001B[49m\u001B[43mtensors\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgrad_tensors_\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 175\u001B[0m \u001B[43m \u001B[49m\u001B[43mallow_unreachable\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43maccumulate_grad\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m)\u001B[49m\n",
"\u001B[0;31mKeyboardInterrupt\u001B[0m: "
]
}
],
@@ -1444,10 +1407,7 @@
"execution_count": 1,
"id": "bd138177d2e4877",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
"collapsed": false
},
"outputs": [],
"source": []