feat: broken LSTM

This commit is contained in:
Yadunand Prem 2024-04-28 00:24:44 +08:00
parent b39b40fa8a
commit 1312e694c3
No known key found for this signature in database

View File

@ -127,8 +127,8 @@
"id": "cded1ed6",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-27T16:12:36.411884Z",
"start_time": "2024-04-27T16:12:35.911757Z"
"end_time": "2024-04-27T16:15:03.602644Z",
"start_time": "2024-04-27T16:15:03.179277Z"
}
},
"outputs": [],
@ -166,8 +166,8 @@
"id": "6297e25a",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-27T16:12:36.450725Z",
"start_time": "2024-04-27T16:12:36.412962Z"
"end_time": "2024-04-27T16:15:06.411332Z",
"start_time": "2024-04-27T16:15:06.392391Z"
}
},
"outputs": [
@ -236,8 +236,8 @@
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-04-27T16:12:36.816993Z",
"start_time": "2024-04-27T16:12:36.451526Z"
"end_time": "2024-04-27T16:15:12.963319Z",
"start_time": "2024-04-27T16:15:12.025487Z"
}
},
"id": "f68b8b1c21eae6d6",
@ -281,8 +281,8 @@
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-04-27T16:12:36.821537Z",
"start_time": "2024-04-27T16:12:36.818392Z"
"end_time": "2024-04-27T16:15:14.719386Z",
"start_time": "2024-04-27T16:15:14.712849Z"
}
},
"id": "3b1f62dd",
@ -320,8 +320,8 @@
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-04-27T16:12:36.841198Z",
"start_time": "2024-04-27T16:12:36.822424Z"
"end_time": "2024-04-27T16:15:17.056545Z",
"start_time": "2024-04-27T16:15:17.014489Z"
}
},
"id": "558f2d74562bc7c8",
@ -525,8 +525,8 @@
"id": "d8dffd7d",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-27T16:12:37.543384Z",
"start_time": "2024-04-27T16:12:36.859114Z"
"end_time": "2024-04-27T16:15:21.497276Z",
"start_time": "2024-04-27T16:15:20.501754Z"
}
},
"outputs": [],
@ -545,24 +545,15 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 19,
"id": "9245ab47",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-27T16:12:37.992484Z",
"start_time": "2024-04-27T16:12:37.544103Z"
"end_time": "2024-04-27T16:19:38.194596Z",
"start_time": "2024-04-27T16:19:37.776094Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/zd/9vyg32393qncxwt_3r_873mh0000gn/T/ipykernel_51446/3747572966.py:7: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_new.cpp:278.)\n",
" X_tensor = torch.tensor(X_train, dtype=torch.float32)\n"
]
}
],
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"# Split train and test\n",
@ -570,8 +561,10 @@
"X_train = [process_video(video) for video in X_train]\n",
"X_test = [process_video(video) for video in X_test]\n",
"\n",
"y_train = np.array(y_train).astype(np.int64)\n",
"\n",
"X_tensor = torch.tensor(X_train, dtype=torch.float32)\n",
"y_tensor = torch.tensor(y_train, dtype=torch.float32)\n",
"y_tensor = torch.tensor(y_train, dtype=torch.long)\n",
"\n",
"train_dataset = torch.utils.data.TensorDataset(X_tensor, y_tensor)\n",
"train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)"
@ -581,9 +574,9 @@
"cell_type": "code",
"outputs": [],
"source": [
"class VideoLSTM(nn.Module):\n",
"class Model(nn.Module):\n",
" def __init__(self):\n",
" super(VideoLSTM, self).__init__()\n",
" super(Model, self).__init__()\n",
" self.input_size = 256\n",
" self.hidden_layers = 64\n",
" self.num_layers = 1\n",
@ -596,23 +589,20 @@
" c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_layers).to(x.device)\n",
"\n",
" # Forward propagate LSTM\n",
" print('prelstm')\n",
" out, _ = self.lstm(x, (h0, c0))\n",
" print('postlstm')\n",
" \n",
" out = self.fc(out[:, -1, :])\n",
" print('postout')\n",
" return out "
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-04-27T16:12:37.996839Z",
"start_time": "2024-04-27T16:12:37.993120Z"
"end_time": "2024-04-27T16:20:46.738811Z",
"start_time": "2024-04-27T16:20:46.734583Z"
}
},
"id": "7396b295037aa70f",
"execution_count": 8
"execution_count": 25
},
{
"cell_type": "code",
@ -648,31 +638,48 @@
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-04-27T16:12:38.001704Z",
"start_time": "2024-04-27T16:12:37.999290Z"
"end_time": "2024-04-27T16:19:57.048696Z",
"start_time": "2024-04-27T16:19:57.045181Z"
}
},
"id": "c3901cf56e12eade",
"execution_count": 9
"execution_count": 21
},
{
"cell_type": "code",
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1, Loss: nan\n",
"Epoch 2, Loss: nan\n",
"Epoch 3, Loss: nan\n",
"Epoch 4, Loss: nan\n",
"Epoch 5, Loss: nan\n",
"Epoch 6, Loss: nan\n",
"Epoch 7, Loss: nan\n",
"Epoch 8, Loss: nan\n",
"Epoch 9, Loss: nan\n",
"Epoch 10, Loss: nan\n"
]
}
],
"source": [
"model = VideoLSTM()\n",
"model = Model()\n",
"lossFn = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
"train_model(model, lossFn, optimizer, train_loader, num_epochs=1)"
"train_model(model, lossFn, optimizer, train_loader, num_epochs=10)"
],
"metadata": {
"collapsed": false,
"is_executing": true,
"ExecuteTime": {
"start_time": "2024-04-27T16:12:17.275816Z"
"end_time": "2024-04-27T16:20:49.798810Z",
"start_time": "2024-04-27T16:20:48.477326Z"
}
},
"id": "dbb00fef60449a02",
"execution_count": null
"execution_count": 26
},
{
"cell_type": "markdown",