From 1312e694c3d36dc306a30417bf248eeb4bedeeba Mon Sep 17 00:00:00 2001 From: Yadunand Prem Date: Sun, 28 Apr 2024 00:24:44 +0800 Subject: [PATCH] feat: broken LSTM --- cs2109s/labs/final/scratchpad.ipynb | 93 ++++++++++++++++------------- 1 file changed, 50 insertions(+), 43 deletions(-) diff --git a/cs2109s/labs/final/scratchpad.ipynb b/cs2109s/labs/final/scratchpad.ipynb index 4487999..dab5ba2 100644 --- a/cs2109s/labs/final/scratchpad.ipynb +++ b/cs2109s/labs/final/scratchpad.ipynb @@ -127,8 +127,8 @@ "id": "cded1ed6", "metadata": { "ExecuteTime": { - "end_time": "2024-04-27T16:12:36.411884Z", - "start_time": "2024-04-27T16:12:35.911757Z" + "end_time": "2024-04-27T16:15:03.602644Z", + "start_time": "2024-04-27T16:15:03.179277Z" } }, "outputs": [], @@ -166,8 +166,8 @@ "id": "6297e25a", "metadata": { "ExecuteTime": { - "end_time": "2024-04-27T16:12:36.450725Z", - "start_time": "2024-04-27T16:12:36.412962Z" + "end_time": "2024-04-27T16:15:06.411332Z", + "start_time": "2024-04-27T16:15:06.392391Z" } }, "outputs": [ @@ -236,8 +236,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-04-27T16:12:36.816993Z", - "start_time": "2024-04-27T16:12:36.451526Z" + "end_time": "2024-04-27T16:15:12.963319Z", + "start_time": "2024-04-27T16:15:12.025487Z" } }, "id": "f68b8b1c21eae6d6", @@ -281,8 +281,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-04-27T16:12:36.821537Z", - "start_time": "2024-04-27T16:12:36.818392Z" + "end_time": "2024-04-27T16:15:14.719386Z", + "start_time": "2024-04-27T16:15:14.712849Z" } }, "id": "3b1f62dd", @@ -320,8 +320,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-04-27T16:12:36.841198Z", - "start_time": "2024-04-27T16:12:36.822424Z" + "end_time": "2024-04-27T16:15:17.056545Z", + "start_time": "2024-04-27T16:15:17.014489Z" } }, "id": "558f2d74562bc7c8", @@ -525,8 +525,8 @@ "id": "d8dffd7d", "metadata": { "ExecuteTime": { - "end_time": "2024-04-27T16:12:37.543384Z", - "start_time": "2024-04-27T16:12:36.859114Z" + "end_time": "2024-04-27T16:15:21.497276Z", + "start_time": "2024-04-27T16:15:20.501754Z" } }, "outputs": [], @@ -545,24 +545,15 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 19, "id": "9245ab47", "metadata": { "ExecuteTime": { - "end_time": "2024-04-27T16:12:37.992484Z", - "start_time": "2024-04-27T16:12:37.544103Z" + "end_time": "2024-04-27T16:19:38.194596Z", + "start_time": "2024-04-27T16:19:37.776094Z" } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/var/folders/zd/9vyg32393qncxwt_3r_873mh0000gn/T/ipykernel_51446/3747572966.py:7: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_new.cpp:278.)\n", - " X_tensor = torch.tensor(X_train, dtype=torch.float32)\n" - ] - } - ], + "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "# Split train and test\n", @@ -570,8 +561,10 @@ "X_train = [process_video(video) for video in X_train]\n", "X_test = [process_video(video) for video in X_test]\n", "\n", + "y_train = np.array(y_train).astype(np.int64)\n", + "\n", "X_tensor = torch.tensor(X_train, dtype=torch.float32)\n", - "y_tensor = torch.tensor(y_train, dtype=torch.float32)\n", + "y_tensor = torch.tensor(y_train, dtype=torch.long)\n", "\n", "train_dataset = torch.utils.data.TensorDataset(X_tensor, y_tensor)\n", "train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)" @@ -581,9 +574,9 @@ "cell_type": "code", "outputs": [], "source": [ - "class VideoLSTM(nn.Module):\n", + "class Model(nn.Module):\n", " def __init__(self):\n", - " super(VideoLSTM, self).__init__()\n", + " super(Model, self).__init__()\n", " self.input_size = 256\n", " self.hidden_layers = 64\n", " self.num_layers = 1\n", @@ -596,23 +589,20 @@ " c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_layers).to(x.device)\n", "\n", " # Forward propagate LSTM\n", - " print('prelstm')\n", " out, _ = self.lstm(x, (h0, c0))\n", - " print('postlstm')\n", " \n", " out = self.fc(out[:, -1, :])\n", - " print('postout')\n", " return out " ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-04-27T16:12:37.996839Z", - "start_time": "2024-04-27T16:12:37.993120Z" + "end_time": "2024-04-27T16:20:46.738811Z", + "start_time": "2024-04-27T16:20:46.734583Z" } }, "id": "7396b295037aa70f", - "execution_count": 8 + "execution_count": 25 }, { "cell_type": "code", @@ -648,31 +638,48 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-04-27T16:12:38.001704Z", - "start_time": "2024-04-27T16:12:37.999290Z" + "end_time": "2024-04-27T16:19:57.048696Z", + "start_time": "2024-04-27T16:19:57.045181Z" } }, "id": "c3901cf56e12eade", - "execution_count": 9 + "execution_count": 21 }, { "cell_type": "code", - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1, Loss: nan\n", + "Epoch 2, Loss: nan\n", + "Epoch 3, Loss: nan\n", + "Epoch 4, Loss: nan\n", + "Epoch 5, Loss: nan\n", + "Epoch 6, Loss: nan\n", + "Epoch 7, Loss: nan\n", + "Epoch 8, Loss: nan\n", + "Epoch 9, Loss: nan\n", + "Epoch 10, Loss: nan\n" + ] + } + ], "source": [ - "model = VideoLSTM()\n", + "model = Model()\n", "lossFn = nn.CrossEntropyLoss()\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", - "train_model(model, lossFn, optimizer, train_loader, num_epochs=1)" + "train_model(model, lossFn, optimizer, train_loader, num_epochs=10)" ], "metadata": { "collapsed": false, - "is_executing": true, "ExecuteTime": { - "start_time": "2024-04-27T16:12:17.275816Z" + "end_time": "2024-04-27T16:20:49.798810Z", + "start_time": "2024-04-27T16:20:48.477326Z" } }, "id": "dbb00fef60449a02", - "execution_count": null + "execution_count": 26 }, { "cell_type": "markdown",