{
"cells": [
{
"cell_type": "markdown",
"id": "7d017333",
"metadata": {},
"source": [
"# Final Assessment Scratch Pad"
]
},
{
"cell_type": "markdown",
"id": "d3d00386",
"metadata": {},
"source": [
"## Instructions"
]
},
{
"cell_type": "markdown",
"id": "ea516aa7",
"metadata": {},
"source": [
"1. Please use only this Jupyter notebook to work on your model, and **do not use any extra files**. If you need to define helper classes or functions, feel free to do so in this notebook.\n",
"2. This template is intended to be general, but it may not cover every use case. The sections are given so that it will be easier for us to grade your submission. If your specific use case isn't addressed, **you may add new Markdown or code blocks to this notebook**. However, please **don't delete any existing blocks**.\n",
"3. If you don't think a particular section of this template is necessary for your work, **you may skip it**. Be sure to explain clearly why you decided to do so."
]
},
{
"cell_type": "markdown",
"id": "022cb4cd",
"metadata": {},
"source": [
"## Report"
]
},
{
"cell_type": "markdown",
"id": "9c14a2d8",
"metadata": {},
"source": [
"**[TODO]**\n",
"\n",
"Please provide a summary of the ideas and steps that led you to your final model. Someone reading this summary should understand why you chose to approach the problem in a particular way and be able to replicate your final model at a high level. Please ensure that your summary is detailed enough to provide an overview of your thought process and approach but also concise enough to be easily understandable. Also, please follow the guidelines given in the `main.ipynb`.\n",
"\n",
"This report should not be longer than **1-2 pages of A4 paper (up to around 1,000 words)**. Marks will be deducted if you do not follow instructions and you include too many words here.\n",
"\n",
"**[DELETE EVERYTHING FROM THE PREVIOUS TODO TO HERE BEFORE SUBMISSION]**\n",
"\n",
"##### Overview\n",
"**[TODO]**\n",
"\n",
"##### 1. Descriptive Analysis\n",
"First, I looked at the target values. They are stored as floats and include NaNs, which is suspicious for a label column. Despite being floats, the non-missing values are ordinal, so I convert them to integers with `Y.fillna(-1).astype(int)`, which maps NaN to -1. Value counts then show only 7 distinct values: 6 actual classes plus the missing label. I drop the 250 samples whose label is NaN (2,250 of the 2,500 remain), so I treat this as a classification problem with 6 classes.\n",
"\n",
"Each entry of `X` is an `n` by 16 by 16 array, with `6 <= n <= 10`. Since every 16 by 16 slice can be viewed as an image, my first idea was to plot the slices, but the plots did not reveal any obviously relevant information.\n",
"\n",
"I then realised that this is a video dataset: each entry is a grayscale video of `n` frames. I zero-pad every video to 10 frames so that the data fits in a single N x 10 x 16 x 16 array. After dropping the unlabelled samples, this array is 2,250 x 10 x 16 x 16, stored as 2,250 x 10 x 256 with each frame flattened.\n",
"\n",
"##### 2. Detection and Handling of Missing Values\n",
"**[TODO]**\n",
"\n",
"##### 3. Detection and Handling of Outliers\n",
"**[TODO]**\n",
"\n",
"##### 4. Detection and Handling of Class Imbalance\n",
"**[TODO]**\n",
"\n",
"##### 5. Understanding Relationship Between Variables\n",
"**[TODO]**\n",
"\n",
"##### 6. Data Visualization\n",
"**[TODO]**\n",
"\n",
"##### 7. General Preprocessing\n",
"**[TODO]**\n",
"\n",
"##### 8. Feature Selection\n",
"**[TODO]**\n",
"\n",
"##### 9. Feature Engineering\n",
"**[TODO]**\n",
"\n",
"##### 10. Creating Models\n",
"**[TODO]**\n",
"\n",
"##### 11. Model Evaluation\n",
"**[TODO]**\n",
"\n",
"##### 12. Hyperparameters Search\n",
"**[TODO]**\n",
"\n",
"##### Conclusion\n",
"**[TODO]**"
]
},
{
"cell_type": "markdown",
"id": "49dcaf29",
"metadata": {},
"source": [
"---"
]
},
{
"cell_type": "markdown",
"id": "27103374",
"metadata": {},
"source": [
"# Workings (Not Graded)\n",
"\n",
"You will do your working below. Note that anything below this section will not be graded, but we might counter-check what you wrote in the report above with your workings to make sure that you actually did what you claimed to have done. "
]
},
{
"cell_type": "markdown",
"id": "0f4c6cd4",
"metadata": {},
"source": [
"## Import Packages\n",
"\n",
"Here, we import some packages necessary to run this notebook. In addition, you may import other packages as well. Do note that when submitting your model, you may only use packages that are available in Coursemology (see `main.ipynb`)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "cded1ed6",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-27T16:12:36.411884Z",
"start_time": "2024-04-27T16:12:35.911757Z"
}
},
"outputs": [],
"source": [
"import os\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"id": "748c35d7",
"metadata": {},
"source": [
"## Load Dataset\n",
"\n",
"The dataset `data.npy` consists of $N$ grayscale videos and their corresponding labels. Each video has a shape of (L, H, W). L represents the length of the video, which may vary between videos. H and W represent the height and width, which are consistent across all videos.\n",
"\n",
"A code snippet that loads the data is provided below."
]
},
{
"cell_type": "markdown",
"id": "c09da291",
"metadata": {},
"source": [
"### Load Data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6297e25a",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-27T16:12:36.450725Z",
"start_time": "2024-04-27T16:12:36.412962Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of data sample: 2500\n",
"Shape of the first data sample: (10, 16, 16)\n",
"Shape of the third data sample: (8, 16, 16)\n"
]
}
],
"source": [
"with open('data.npy', 'rb') as f:\n",
"    data = np.load(f, allow_pickle=True).item()\n",
"    X = data['data']\n",
"    y = data['label']\n",
"\n",
"print('Number of data sample:', len(X))\n",
"print('Shape of the first data sample:', X[0].shape)\n",
"print('Shape of the third data sample:', X[2].shape)"
]
},
{
"cell_type": "markdown",
"id": "cbe832b6",
"metadata": {},
"source": [
"## Data Exploration & Preparation"
]
},
{
"cell_type": "code",
"outputs": [],
"source": [
"from sklearn.preprocessing import OrdinalEncoder\n",
"\n",
"# Some helper functions\n",
"def show_images(images, n_row=5, n_col=5, figsize=(12, 12)):\n",
"    # Plot up to n_row * n_col grayscale images on a grid\n",
"    _, axs = plt.subplots(n_row, n_col, figsize=figsize)\n",
"    axs = axs.flatten()\n",
"    for img, ax in zip(images, axs):\n",
"        ax.imshow(img, cmap='gray')\n",
"    plt.show()\n",
"\n",
"def nan_columns(X, threshold=0.5):\n",
"    # Columns in which at least a `threshold` fraction of the rows is NaN\n",
"    count = X.shape[0] * threshold\n",
"    nan_counts = X.isna().sum()\n",
"    return nan_counts[nan_counts >= count].index\n",
"\n",
"def zero_columns(X, threshold=0.5):\n",
"    # Columns in which at least a `threshold` fraction of the rows is 0\n",
"    count = X.shape[0] * threshold\n",
"    zero_cols = (X == 0).sum()\n",
"    return zero_cols[zero_cols >= count].index\n",
"\n",
"def object_columns(X):\n",
"    return X.dtypes[X.dtypes == 'object'].index\n",
"\n",
"def convert_to_ordinal(X, columns):\n",
"    encoder = OrdinalEncoder()\n",
"    return encoder.fit_transform(X[columns])\n",
"\n",
"def correlated_columns(X, threshold=0.99):\n",
"    # Columns whose correlation with an earlier column exceeds `threshold`\n",
"    corr = X.corr()\n",
"    upper = corr.where(np.triu(np.ones(corr.shape), k=1).astype(bool))\n",
"    return [column for column in upper.columns if any(upper[column] > threshold)]"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-04-27T16:12:36.816993Z",
"start_time": "2024-04-27T16:12:36.451526Z"
}
},
"id": "f68b8b1c21eae6d6",
"execution_count": 3
},
{
"cell_type": "markdown",
"id": "2f6a464c",
"metadata": {},
"source": [
"### 1. Descriptive Analysis"
]
},
{
"cell_type": "code",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2250\n"
]
}
],
"source": [
"not_nan_indices = np.argwhere(~np.isnan(np.array(y))).squeeze()\n",
"print(len(not_nan_indices))\n",
"y_filtered = [y[i] for i in not_nan_indices]\n",
"x_filtered = [X[i] for i in not_nan_indices]\n",
"X = x_filtered\n",
"y = y_filtered\n",
"# show_images(X[0], 2, 5, [16, 16])\n",
"Y = pd.DataFrame(y)\n",
"# show_images(X[0], 1, 10, [10, 1])\n",
"# show_images(X[1], 1, 10, [10, 1])\n",
"# show_images(X[2], 1, 10, [10, 1])\n",
"# show_images(X[3], 1, 10, [10, 1])\n",
"# Y[:10].T\n",
"# print(type(X[0]))"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-04-27T16:12:36.821537Z",
"start_time": "2024-04-27T16:12:36.818392Z"
}
},
"id": "3b1f62dd",
"execution_count": 4
},
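{
"cell_type": "markdown",
"id": "sketch-label-counts-md",
"metadata": {},
"source": [
"The report derives the number of classes from `Y.fillna(-1).astype(int)` and its value counts, and it states that `6 <= n <= 10`. The cell below is a minimal sketch that recomputes both from the unfiltered `data` dictionary loaded earlier. It assumes that `data['label']` is a 1-D sequence of floats and that every `data['data'][i]` is an `(n, 16, 16)` array. The names `raw_labels` and `lengths` are illustrative and are not used elsewhere in the notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-label-counts-code",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: recompute the label and video-length counts from the unfiltered data\n",
"raw_labels = pd.Series(data['label']).fillna(-1).astype(int)  # NaN labels become -1\n",
"print(raw_labels.value_counts().sort_index())\n",
"\n",
"lengths = pd.Series([video.shape[0] for video in data['data']])  # number of frames per video\n",
"print(lengths.value_counts().sort_index())"
]
},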
{
"cell_type": "code",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(2250, 10, 256)\n"
]
}
],
"source": [
"# Pad every video to L_max = 10 frames and flatten each 16 x 16 frame into 256 values\n",
"L_max = 10\n",
"\n",
"def process_video(video, L_max=L_max):\n",
"    L = video.shape[0]\n",
"    if L < L_max:\n",
"        # Append all-zero frames at the end of the sequence\n",
"        padding = np.zeros((L_max - L, *video.shape[1:]))\n",
"        video = np.concatenate([video, padding])\n",
"    # Return the same dtype for padded and unpadded videos\n",
"    return video.reshape(L_max, -1).astype(np.float32)\n",
"\n",
"X_array = np.stack([process_video(video) for video in X])\n",
"print(X_array.shape)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-04-27T16:12:36.841198Z",
"start_time": "2024-04-27T16:12:36.822424Z"
}
},
"id": "558f2d74562bc7c8",
"execution_count": 5
},
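{
"cell_type": "markdown",
"id": "sketch-packed-seq-md",
"metadata": {},
"source": [
"Zero-padding interacts with the model defined in section 10. `VideoLSTM` classifies from `out[:, -1, :]`, so for a video shorter than 10 frames it uses the LSTM output produced after reading the appended zero frames. One alternative, sketched below and not what the notebook currently does, is to keep the true lengths and pass a packed sequence to `nn.LSTM`. The final hidden state `h_n` is then taken at each video's last real frame. The hidden size of 64 mirrors `VideoLSTM`. The names `lengths`, `padded` and `lstm` are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-packed-seq-code",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: let the LSTM skip the zero-padded frames by packing the sequences\n",
"import torch\n",
"from torch import nn\n",
"from torch.nn.utils.rnn import pack_padded_sequence\n",
"\n",
"lengths = torch.tensor([video.shape[0] for video in X])  # true number of frames per video\n",
"padded = torch.tensor(np.stack([process_video(video) for video in X]), dtype=torch.float32)  # (N, 10, 256)\n",
"\n",
"packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)\n",
"lstm = nn.LSTM(input_size=256, hidden_size=64, batch_first=True)\n",
"_, (h_n, _) = lstm(packed)  # h_n[-1, i] is the state after the last real frame of video i\n",
"print(h_n.shape)  # (num_layers, N, hidden_size)"
]
},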
{
"cell_type": "markdown",
"id": "adb61967",
"metadata": {},
"source": [
"### 2. Detection and Handling of Missing Values"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "4bb9cdfb",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-27T16:12:36.843437Z",
"start_time": "2024-04-27T16:12:36.842009Z"
}
},
"outputs": [],
"source": []
},
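{
"cell_type": "markdown",
"id": "sketch-missing-md",
"metadata": {},
"source": [
"Section 1 already removed the NaN labels, but missing values could also occur inside the frames. The cell below is a minimal check that counts NaN pixels and all-zero frames in each video. It assumes `X` is the filtered list of `(n, 16, 16)` arrays from section 1. All-zero frames are counted before padding because the padding step appends exactly such frames, which would then be indistinguishable from genuinely blank ones."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-missing-code",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: look for missing pixels and blank frames before any padding is applied\n",
"nan_pixels = np.array([np.isnan(video).sum() for video in X])\n",
"blank_frames = np.array([(video.reshape(video.shape[0], -1) == 0).all(axis=1).sum() for video in X])\n",
"\n",
"print('videos with NaN pixels:', int((nan_pixels > 0).sum()))\n",
"print('total NaN pixels:', int(nan_pixels.sum()))\n",
"print('videos with all-zero frames:', int((blank_frames > 0).sum()))"
]
},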
{
"cell_type": "markdown",
"id": "8adcb9cd",
"metadata": {},
"source": [
"### 3. Detection and Handling of Outliers"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "ed1c17a1",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-27T16:12:36.845318Z",
"start_time": "2024-04-27T16:12:36.843930Z"
}
},
"outputs": [],
"source": []
},
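{
"cell_type": "markdown",
"id": "sketch-outliers-md",
"metadata": {},
"source": [
"One simple way to flag unusual videos is a z-score on each video's mean pixel intensity. This sketch assumes that overall brightness is a meaningful per-video summary. The cut-off of 3 standard deviations is a common rule of thumb, not a value tuned for this dataset, and `video_means`, `z_scores` and `outlier_idx` are illustrative names."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-outliers-code",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: flag videos whose mean intensity is more than 3 standard deviations from the average\n",
"video_means = np.array([np.nanmean(video) for video in X])\n",
"z_scores = (video_means - video_means.mean()) / video_means.std()\n",
"outlier_idx = np.flatnonzero(np.abs(z_scores) > 3)\n",
"\n",
"print('candidate outlier videos:', len(outlier_idx))\n",
"print('first few indices:', outlier_idx[:10])"
]
},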
{
"cell_type": "markdown",
"id": "d4916043",
"metadata": {},
"source": [
"### 4. Detection and Handling of Class Imbalance"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "ad3ab20e",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-27T16:12:36.847266Z",
"start_time": "2024-04-27T16:12:36.845985Z"
}
},
"outputs": [],
"source": []
},
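{
"cell_type": "markdown",
"id": "sketch-class-weights-md",
"metadata": {},
"source": [
"The cell below counts the filtered labels and derives inverse-frequency (balanced) class weights, n_samples / (n_classes * count). This is the same formula scikit-learn uses for `class_weight='balanced'`. The sketch assumes the labels are the integers 0 to 5, which is consistent with `num_classes = 6` in `VideoLSTM`. If needed, `class_weights` could be passed to `nn.CrossEntropyLoss(weight=...)` as a float tensor."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-class-weights-code",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: class counts and balanced class weights (assumes labels are the integers 0..5)\n",
"y_int = np.asarray(y).astype(int)\n",
"counts = np.bincount(y_int, minlength=6)\n",
"class_weights = len(y_int) / (len(counts) * np.maximum(counts, 1))  # guard against empty classes\n",
"\n",
"for label, (count, weight) in enumerate(zip(counts, class_weights)):\n",
"    print(f'class {label}: {count} samples, weight {weight:.3f}')"
]
},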
{
"cell_type": "markdown",
"id": "2552a795",
"metadata": {},
"source": [
"### 5. Understanding Relationship Between Variables"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "29ddbbcf",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-27T16:12:36.849483Z",
"start_time": "2024-04-27T16:12:36.848012Z"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "757fb315",
"metadata": {},
"source": [
"### 6. Data Visualization"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "93f82e42",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-27T16:12:36.852862Z",
"start_time": "2024-04-27T16:12:36.851617Z"
}
},
"outputs": [],
"source": []
},
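{
"cell_type": "markdown",
"id": "sketch-class-means-md",
"metadata": {},
"source": [
"The report notes that plotting individual frames revealed nothing obvious. Averaging over many videos can make class-level patterns easier to see. This sketch pads the filtered videos with `process_video` and uses `show_images` to display the mean first frame of each class, assuming the labels are 0 to 5. It is an exploratory idea, not a result reported above."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-class-means-code",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: mean first frame per class (assumes labels 0..5 and 16 x 16 frames)\n",
"padded = np.stack([process_video(video) for video in X]).reshape(-1, 10, 16, 16)\n",
"y_int = np.asarray(y).astype(int)\n",
"# Frame 0 is always a real frame, since every video has at least 6 frames\n",
"mean_first_frames = [padded[y_int == label, 0].mean(axis=0) for label in range(6)]\n",
"show_images(mean_first_frames, n_row=1, n_col=6, figsize=(12, 2))"
]
},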
{
"cell_type": "markdown",
"id": "2a7eebcf",
"metadata": {},
"source": [
"## Data Preprocessing"
]
},
{
"cell_type": "markdown",
"id": "ae3e3383",
"metadata": {},
"source": [
"### 7. General Preprocessing"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "19174365",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-27T16:12:36.854714Z",
"start_time": "2024-04-27T16:12:36.853430Z"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "fb3aa527",
"metadata": {},
"source": [
"### 8. Feature Selection"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "a85808bf",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-27T16:12:36.856476Z",
"start_time": "2024-04-27T16:12:36.855157Z"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "4921e8ca",
"metadata": {},
"source": [
"### 9. Feature Engineering"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "dbcde626",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-27T16:12:36.858522Z",
"start_time": "2024-04-27T16:12:36.857080Z"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "fa676c3f",
"metadata": {},
"source": [
"## Modeling & Evaluation"
]
},
{
"cell_type": "markdown",
"id": "589b37e4",
"metadata": {},
"source": [
"### 10. Creating models"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "d8dffd7d",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-27T16:12:37.543384Z",
"start_time": "2024-04-27T16:12:36.859114Z"
}
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn"
]
},
{
"cell_type": "markdown",
"id": "495bf3c0",
"metadata": {},
"source": [
"### 11. Model Evaluation"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "9245ab47",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-27T16:12:37.992484Z",
"start_time": "2024-04-27T16:12:37.544103Z"
}
},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Split train and test; stratify so that the class proportions match in both splits\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, stratify=y, random_state=42)\n",
"X_train = [process_video(video) for video in X_train]\n",
"X_test = [process_video(video) for video in X_test]\n",
"\n",
"# Stack into one array first: building a tensor from a list of arrays is very slow\n",
"X_tensor = torch.tensor(np.stack(X_train), dtype=torch.float32)\n",
"# nn.CrossEntropyLoss expects integer class indices (assumed to be 0..5)\n",
"y_tensor = torch.tensor(np.asarray(y_train), dtype=torch.long)\n",
"\n",
"train_dataset = torch.utils.data.TensorDataset(X_tensor, y_tensor)\n",
"train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)"
]
},
{
"cell_type": "code",
"outputs": [],
"source": [
"class VideoLSTM(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.input_size = 256     # one flattened 16 x 16 frame per time step\n",
"        self.hidden_layers = 64   # LSTM hidden size\n",
"        self.num_layers = 1\n",
"        self.num_classes = 6\n",
"\n",
"        self.lstm = nn.LSTM(self.input_size, self.hidden_layers, self.num_layers, batch_first=True)\n",
"        self.fc = nn.Linear(self.hidden_layers, self.num_classes)\n",
"\n",
"    def forward(self, x):\n",
"        # x has shape (batch, 10 frames, 256)\n",
"        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_layers, device=x.device)\n",
"        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_layers, device=x.device)\n",
"\n",
"        # Forward propagate the LSTM and classify from the output at the last time step\n",
"        out, _ = self.lstm(x, (h0, c0))\n",
"        return self.fc(out[:, -1, :])"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-04-27T16:12:37.996839Z",
"start_time": "2024-04-27T16:12:37.993120Z"
}
},
"id": "7396b295037aa70f",
"execution_count": 8
},
{
"cell_type": "code",
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-04-27T16:12:37.998501Z",
"start_time": "2024-04-27T16:12:37.997472Z"
}
},
"id": "9057629fbaaa8571",
"execution_count": 8
},
{
"cell_type": "code",
"outputs": [],
"source": [
"def train_model(model, loss_fn, optimizer, train_loader, num_epochs=10):\n",
"    model.train()\n",
"    for epoch in range(num_epochs):\n",
"        running_loss = 0.0\n",
"        for inputs, labels in train_loader:\n",
"            optimizer.zero_grad()\n",
"            outputs = model(inputs)\n",
"            loss = loss_fn(outputs, labels)\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"            running_loss += loss.item()\n",
"        # Mean loss over the batches of this epoch\n",
"        print(f\"Epoch {epoch + 1}, Loss: {running_loss / len(train_loader):.4f}\")\n"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-04-27T16:12:38.001704Z",
"start_time": "2024-04-27T16:12:37.999290Z"
}
},
"id": "c3901cf56e12eade",
"execution_count": 9
},
{
"cell_type": "code",
"outputs": [],
"source": [
"model = VideoLSTM()\n",
"loss_fn = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
"train_model(model, loss_fn, optimizer, train_loader, num_epochs=1)"
],
"metadata": {
"collapsed": false,
"is_executing": true,
"ExecuteTime": {
"start_time": "2024-04-27T16:12:17.275816Z"
}
},
"id": "dbb00fef60449a02",
"execution_count": null
},
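{
"cell_type": "markdown",
"id": "sketch-evaluation-md",
"metadata": {},
"source": [
"The model has not been scored yet. A minimal sketch, assuming the `model`, `X_test` and `y_test` from the cells above and labels 0 to 5, evaluates the held-out split with accuracy and macro-averaged F1. Macro F1 weights all six classes equally, which matters if the classes turn out to be imbalanced."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-evaluation-code",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: score the trained model on the held-out split (assumes labels are the integers 0..5)\n",
"from sklearn.metrics import accuracy_score, f1_score\n",
"\n",
"X_test_tensor = torch.tensor(np.stack(X_test), dtype=torch.float32)\n",
"y_true = np.asarray(y_test).astype(int)\n",
"\n",
"model.eval()  # switch off training-only behaviour such as dropout\n",
"with torch.no_grad():\n",
"    y_pred = model(X_test_tensor).argmax(dim=1).numpy()\n",
"\n",
"print('accuracy:', accuracy_score(y_true, y_pred))\n",
"print('macro F1:', f1_score(y_true, y_pred, average='macro'))"
]
},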
{
"cell_type": "markdown",
"id": "8aa31404",
"metadata": {},
"source": [
"### 12. Hyperparameters Search"
]
},
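{
"cell_type": "markdown",
"id": "sketch-grid-search-md",
"metadata": {},
"source": [
"The sketch below runs a small grid over the Adam learning rate and the LSTM hidden size. Each configuration is trained for a few epochs on 80% of `X_train` and scored on the remaining 20%, so the held-out `X_test` stays untouched. `TunableLSTM` is a hypothetical, configurable copy of `VideoLSTM`. The grid values and the 5 epochs are illustrative assumptions rather than tuned choices, and labels are assumed to be 0 to 5."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-grid-search-code",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: small grid search over learning rate and hidden size on a validation split\n",
"import itertools\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"class TunableLSTM(nn.Module):  # hypothetical variant of VideoLSTM with a configurable hidden size\n",
"    def __init__(self, hidden_size, input_size=256, num_classes=6):\n",
"        super().__init__()\n",
"        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)\n",
"        self.fc = nn.Linear(hidden_size, num_classes)\n",
"\n",
"    def forward(self, x):\n",
"        out, _ = self.lstm(x)  # zero initial states, as in VideoLSTM\n",
"        return self.fc(out[:, -1, :])\n",
"\n",
"X_np = np.stack(X_train).astype(np.float32)\n",
"y_np = np.asarray(y_train).astype(np.int64)\n",
"X_tr, X_val, y_tr, y_val = train_test_split(X_np, y_np, test_size=0.2, stratify=y_np, random_state=0)\n",
"tr_loader = torch.utils.data.DataLoader(\n",
"    torch.utils.data.TensorDataset(torch.from_numpy(X_tr), torch.from_numpy(y_tr)), batch_size=32, shuffle=True)\n",
"\n",
"results = {}\n",
"for lr, hidden in itertools.product([1e-3, 3e-3], [32, 64, 128]):\n",
"    torch.manual_seed(0)  # same seed for every configuration\n",
"    net = TunableLSTM(hidden)\n",
"    opt = torch.optim.Adam(net.parameters(), lr=lr)\n",
"    criterion = nn.CrossEntropyLoss()\n",
"    for epoch in range(5):\n",
"        net.train()\n",
"        for xb, yb in tr_loader:\n",
"            opt.zero_grad()\n",
"            criterion(net(xb), yb).backward()\n",
"            opt.step()\n",
"    net.eval()\n",
"    with torch.no_grad():\n",
"        val_pred = net(torch.from_numpy(X_val)).argmax(dim=1).numpy()\n",
"    results[(lr, hidden)] = (val_pred == y_val).mean()\n",
"    print(f'lr={lr}, hidden={hidden}: validation accuracy {results[(lr, hidden)]:.3f}')\n",
"\n",
"print('best (lr, hidden):', max(results, key=results.get))"
]
},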
{
"cell_type": "code",
"execution_count": null,
"id": "81addd51",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 5
}