feat: broken LSTM
This commit is contained in:
parent
b39b40fa8a
commit
1312e694c3
@ -127,8 +127,8 @@
|
||||
"id": "cded1ed6",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-04-27T16:12:36.411884Z",
|
||||
"start_time": "2024-04-27T16:12:35.911757Z"
|
||||
"end_time": "2024-04-27T16:15:03.602644Z",
|
||||
"start_time": "2024-04-27T16:15:03.179277Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@ -166,8 +166,8 @@
|
||||
"id": "6297e25a",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-04-27T16:12:36.450725Z",
|
||||
"start_time": "2024-04-27T16:12:36.412962Z"
|
||||
"end_time": "2024-04-27T16:15:06.411332Z",
|
||||
"start_time": "2024-04-27T16:15:06.392391Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
@ -236,8 +236,8 @@
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-04-27T16:12:36.816993Z",
|
||||
"start_time": "2024-04-27T16:12:36.451526Z"
|
||||
"end_time": "2024-04-27T16:15:12.963319Z",
|
||||
"start_time": "2024-04-27T16:15:12.025487Z"
|
||||
}
|
||||
},
|
||||
"id": "f68b8b1c21eae6d6",
|
||||
@ -281,8 +281,8 @@
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-04-27T16:12:36.821537Z",
|
||||
"start_time": "2024-04-27T16:12:36.818392Z"
|
||||
"end_time": "2024-04-27T16:15:14.719386Z",
|
||||
"start_time": "2024-04-27T16:15:14.712849Z"
|
||||
}
|
||||
},
|
||||
"id": "3b1f62dd",
|
||||
@ -320,8 +320,8 @@
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-04-27T16:12:36.841198Z",
|
||||
"start_time": "2024-04-27T16:12:36.822424Z"
|
||||
"end_time": "2024-04-27T16:15:17.056545Z",
|
||||
"start_time": "2024-04-27T16:15:17.014489Z"
|
||||
}
|
||||
},
|
||||
"id": "558f2d74562bc7c8",
|
||||
@ -525,8 +525,8 @@
|
||||
"id": "d8dffd7d",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-04-27T16:12:37.543384Z",
|
||||
"start_time": "2024-04-27T16:12:36.859114Z"
|
||||
"end_time": "2024-04-27T16:15:21.497276Z",
|
||||
"start_time": "2024-04-27T16:15:20.501754Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@ -545,24 +545,15 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 19,
|
||||
"id": "9245ab47",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-04-27T16:12:37.992484Z",
|
||||
"start_time": "2024-04-27T16:12:37.544103Z"
|
||||
"end_time": "2024-04-27T16:19:38.194596Z",
|
||||
"start_time": "2024-04-27T16:19:37.776094Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/var/folders/zd/9vyg32393qncxwt_3r_873mh0000gn/T/ipykernel_51446/3747572966.py:7: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_new.cpp:278.)\n",
|
||||
" X_tensor = torch.tensor(X_train, dtype=torch.float32)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"# Split train and test\n",
|
||||
@ -570,8 +561,10 @@
|
||||
"X_train = [process_video(video) for video in X_train]\n",
|
||||
"X_test = [process_video(video) for video in X_test]\n",
|
||||
"\n",
|
||||
"y_train = np.array(y_train).astype(np.int64)\n",
|
||||
"\n",
|
||||
"X_tensor = torch.tensor(X_train, dtype=torch.float32)\n",
|
||||
"y_tensor = torch.tensor(y_train, dtype=torch.float32)\n",
|
||||
"y_tensor = torch.tensor(y_train, dtype=torch.long)\n",
|
||||
"\n",
|
||||
"train_dataset = torch.utils.data.TensorDataset(X_tensor, y_tensor)\n",
|
||||
"train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)"
|
||||
@ -581,9 +574,9 @@
|
||||
"cell_type": "code",
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class VideoLSTM(nn.Module):\n",
|
||||
"class Model(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(VideoLSTM, self).__init__()\n",
|
||||
" super(Model, self).__init__()\n",
|
||||
" self.input_size = 256\n",
|
||||
" self.hidden_layers = 64\n",
|
||||
" self.num_layers = 1\n",
|
||||
@ -596,23 +589,20 @@
|
||||
" c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_layers).to(x.device)\n",
|
||||
"\n",
|
||||
" # Forward propagate LSTM\n",
|
||||
" print('prelstm')\n",
|
||||
" out, _ = self.lstm(x, (h0, c0))\n",
|
||||
" print('postlstm')\n",
|
||||
" \n",
|
||||
" out = self.fc(out[:, -1, :])\n",
|
||||
" print('postout')\n",
|
||||
" return out "
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-04-27T16:12:37.996839Z",
|
||||
"start_time": "2024-04-27T16:12:37.993120Z"
|
||||
"end_time": "2024-04-27T16:20:46.738811Z",
|
||||
"start_time": "2024-04-27T16:20:46.734583Z"
|
||||
}
|
||||
},
|
||||
"id": "7396b295037aa70f",
|
||||
"execution_count": 8
|
||||
"execution_count": 25
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
@ -648,31 +638,48 @@
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-04-27T16:12:38.001704Z",
|
||||
"start_time": "2024-04-27T16:12:37.999290Z"
|
||||
"end_time": "2024-04-27T16:19:57.048696Z",
|
||||
"start_time": "2024-04-27T16:19:57.045181Z"
|
||||
}
|
||||
},
|
||||
"id": "c3901cf56e12eade",
|
||||
"execution_count": 9
|
||||
"execution_count": 21
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 1, Loss: nan\n",
|
||||
"Epoch 2, Loss: nan\n",
|
||||
"Epoch 3, Loss: nan\n",
|
||||
"Epoch 4, Loss: nan\n",
|
||||
"Epoch 5, Loss: nan\n",
|
||||
"Epoch 6, Loss: nan\n",
|
||||
"Epoch 7, Loss: nan\n",
|
||||
"Epoch 8, Loss: nan\n",
|
||||
"Epoch 9, Loss: nan\n",
|
||||
"Epoch 10, Loss: nan\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model = VideoLSTM()\n",
|
||||
"model = Model()\n",
|
||||
"lossFn = nn.CrossEntropyLoss()\n",
|
||||
"optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
|
||||
"train_model(model, lossFn, optimizer, train_loader, num_epochs=1)"
|
||||
"train_model(model, lossFn, optimizer, train_loader, num_epochs=10)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"is_executing": true,
|
||||
"ExecuteTime": {
|
||||
"start_time": "2024-04-27T16:12:17.275816Z"
|
||||
"end_time": "2024-04-27T16:20:49.798810Z",
|
||||
"start_time": "2024-04-27T16:20:48.477326Z"
|
||||
}
|
||||
},
|
||||
"id": "dbb00fef60449a02",
|
||||
"execution_count": null
|
||||
"execution_count": 26
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
|
Loading…
Reference in New Issue
Block a user