{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "accelerator": "GPU", "colab": { "name": "Female_names_generator_using_charRNN.ipynb", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "SF2XaOyEqjcU", "colab_type": "text" }, "source": [ "#
CS568: Deep Learning — Female name generator using a character-level RNN (LSTM)
Spring 2020
" ] }, { "cell_type": "code", "metadata": { "id": "MweShd4nqqMd", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 122 }, "outputId": "153374a2-e92e-4b4c-e10f-d4cf03b07703" }, "source": [ "from google.colab import drive\n", "drive.mount('/content/drive')" ], "execution_count": 1, "outputs": [ { "output_type": "stream", "text": [ "Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly\n", "\n", "Enter your authorization code:\n", "··········\n", "Mounted at /content/drive\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "SirtjPxJqjcV", "colab_type": "text" }, "source": [ "## Load Data" ] }, { "cell_type": "code", "metadata": { "colab_type": "code", "id": "HXA8rCMRJEBH", "colab": {} }, "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from keras.models import Sequential\n",
    "from keras import layers\n",
    "from keras.optimizers import Adam\n",
    "import sys\n",
    "\n",
    "max_length = 2  # characters of context the model sees per prediction\n",
    "learning_rate = 0.01\n",
    "num_epochs = 100\n",
    "batch_size = 1000\n",
    "\n",
    "def load_data(filename):\n",
    "    \"\"\"Read the 'Name' column of a CSV file into an (n_names, 1) array.\n",
    "\n",
    "    Rows with a missing name are dropped so they do not reach preprocess_data\n",
    "    as the literal string 'nan'. Raises KeyError if there is no 'Name' column.\n",
    "    \"\"\"\n",
    "    df = pd.read_csv(filename)\n",
    "    # df[['Name']] (unlike df.filter) fails loudly when the column is missing\n",
    "    # instead of silently returning a frame with no columns.\n",
    "    names = df[['Name']].dropna()\n",
    "    return np.array(names)\n",
    "\n",
    "def preprocess_data(data):\n",
    "    \"\"\"Turn names into one-hot (context window, next character) training pairs.\n",
    "\n",
    "    Every run of `max_length` consecutive characters inside a name becomes one\n",
    "    input sample, and the character that follows it becomes its target.\n",
    "\n",
    "    Args:\n",
    "        data: sequence of names; each item may be a 1-element array, as\n",
    "            returned by load_data (it is squeezed and converted to str).\n",
    "\n",
    "    Returns:\n",
    "        X: bool array of shape (n_samples, max_length, n_chars), one-hot inputs.\n",
    "        t: bool array of shape (n_samples, n_chars), one-hot targets.\n",
    "        chars: sorted list of the distinct characters found in `data`.\n",
    "        char_indices: dict mapping each character to its index in `chars`.\n",
    "    \"\"\"\n",
    "    inputs = []\n",
    "    targets = []\n",
    "    vocab = set()\n",
    "    for item in data:\n",
    "        item = str(np.squeeze(item))\n",
    "\n",
    "        # track all possible characters to generate\n",
    "        vocab.update(item)\n",
    "\n",
    "        # create (context window, next character) pairs from each name\n",
    "        for i in range(len(item) - max_length):\n",
    "            inputs.append(item[i : i + max_length])\n",
    "            targets.append(item[i + max_length])\n",
    "\n",
    "    # sorted list of unique characters to generate from\n",
    "    chars = sorted(vocab)\n",
    "    data_size, chars_size = len(data), len(chars)\n",
    "    # len(data) counts names, not characters\n",
    "    print(\"Data has {} names, {} unique characters\".format(data_size, chars_size))\n",
    "\n",
    "    char_indices = {ch: idx for idx, ch in enumerate(chars)}\n",
    "\n",
    "    # builtin bool: the np.bool alias was removed in NumPy 1.24\n",
    "    X = np.zeros((len(inputs), max_length, chars_size), dtype=bool)\n",
    "    t = np.zeros((len(inputs), chars_size), dtype=bool)\n",
    "\n",
    "    # one-hot encode each input window and its target character\n",
    "    for row, window in enumerate(inputs):\n",
    "        for pos, ch in enumerate(window):\n",
    "            X[row, pos, char_indices[ch]] = 1\n",
    "        t[row, char_indices[targets[row]]] = 1\n",
    "\n",
    "    return X, t, chars, char_indices"
   ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "mgs62W3NMbr1" }, "source": [ "## Define model" ] }, { "cell_type": "code", "metadata": { "colab_type": "code", "id": "82ABbbwyL5eZ", "colab": {} }, "source": [
    "def build_model(chars_length, hidden_units=128):\n",
    "    \"\"\"Build and compile the next-character model.\n",
    "\n",
    "    An LSTM reads a one-hot window of shape (max_length, chars_length) and a\n",
    "    softmax Dense layer outputs a distribution over the next character.\n",
    "\n",
    "    Args:\n",
    "        chars_length: vocabulary size (number of distinct characters).\n",
    "        hidden_units: number of LSTM units (128 by default).\n",
    "\n",
    "    Returns:\n",
    "        A compiled keras Sequential model.\n",
    "    \"\"\"\n",
    "    model = Sequential()\n",
    "    model.add(layers.LSTM(hidden_units, input_shape=(max_length, chars_length)))\n",
    "    model.add(layers.Dense(chars_length, activation='softmax'))\n",
    "    # `learning_rate` replaces the deprecated `lr` keyword, which newer Keras\n",
    "    # versions reject.\n",
    "    model.compile(loss='categorical_crossentropy', optimizer=Adam(learning_rate=learning_rate))\n",
    "    return model"
   ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "colab_type": "code", "id": "BMFcVFdZOX-e", "colab": {} }, "source": [
    "def sample(output, total):\n",
    "    \"\"\"Draw a vocabulary index from a softmax distribution with a temperature.\n",
    "\n",
    "    Args:\n",
    "        output: 1-D array of probabilities (the model's softmax output).\n",
    "        total: sampling temperature. The distribution is reshaped to\n",
    "            p ** (1 / total): values below 1 sharpen it, values above 1\n",
    "            flatten it, and a value <= 0 picks the most likely index.\n",
    "\n",
    "    Returns:\n",
    "        The sampled index.\n",
    "    \"\"\"\n",
    "    output = np.asarray(output).astype('float64')\n",
    "    if total <= 0:\n",
    "        return int(np.argmax(output))\n",
    "    with np.errstate(divide='ignore'):\n",
    "        logits = np.log(output) / total  # log(0) -> -inf -> probability 0\n",
    "    # Shift by the max before exponentiating so a small temperature cannot\n",
    "    # underflow every entry to 0 (which would make the normalisation NaN).\n",
    "    exp_output = np.exp(logits - np.max(logits))\n",
    "    probs = exp_output / np.sum(exp_output)\n",
    "    probas = np.random.multinomial(1, probs, 1)\n",
    "    return int(np.argmax(probas))\n",
    "\n",
    "def generate_names(seed, length, chars, char_indices, char_length, temperature=0.5):\n",
    "    \"\"\"Extend `seed` one sampled character at a time up to `length` characters.\n",
    "\n",
    "    Predictions come from the notebook-global `model` created in the training\n",
    "    cell, so that cell must be run before this function is called.\n",
    "\n",
    "    Args:\n",
    "        seed: starting characters; each must be a key of `char_indices`.\n",
    "            Only its last `max_length` characters are used as context.\n",
    "        length: total length of the returned name; a seed that is already\n",
    "            this long (or longer) is returned unchanged.\n",
    "        chars: index -> character list from preprocess_data.\n",
    "        char_indices: character -> index dict from preprocess_data.\n",
    "        char_length: vocabulary size, i.e. len(chars).\n",
    "        temperature: passed to sample(); 0.5 matches the earlier fixed value.\n",
    "\n",
    "    Returns:\n",
    "        The generated name as a string.\n",
    "    \"\"\"\n",
    "    name = seed\n",
    "    # The model input is exactly max_length characters wide, so the context is\n",
    "    # always the last max_length characters produced so far. Slicing from the\n",
    "    # end (rather than dropping the first character every step) also keeps\n",
    "    # full context for seeds shorter than max_length, such as 'Z'.\n",
    "    window = seed[-max_length:]\n",
    "\n",
    "    for _ in range(length - len(seed)):\n",
    "        sampled = np.zeros((1, max_length, char_length))\n",
    "        for pos, ch in enumerate(window):\n",
    "            sampled[0, pos, char_indices[ch]] = 1.\n",
    "\n",
    "        preds = model.predict(sampled, verbose=0)[0]\n",
    "        next_char = chars[sample(preds, temperature)]\n",
    "\n",
    "        window = (window + next_char)[-max_length:]\n",
    "        name += next_char\n",
    "\n",
    "    return name\n"
   ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "colab_type": "code", "id": "DDsA2NgKNYB0", "outputId": "f312858d-b83b-40ee-9157-da88c6fa9007", "colab": { "base_uri": "https://localhost:8080/", "height": 1000 } }, "source": [ "data = load_data(\"/content/drive/My Drive/muslim_girls_names.csv\")\n", "print(data.shape, data.dtype)\n", "X, t, chars, char_indices = preprocess_data(data)\n", "print(\"X.shape \", X.shape)\n", "print(\"t.shape \", t.shape)\n", "\n", "model = build_model(len(chars))\n", "model.summary()\n", "model.fit(X, t, epochs=num_epochs, batch_size=batch_size)" ], "execution_count": 6, "outputs": [ { "output_type": "stream", "text": [ "(4442, 1) object\n", "Data has 4442 characters, 52 unique\n", "X.shape (19500, 2, 52)\n", "t.shape (19500, 52)\n", "Model: \"sequential_1\"\n", "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "lstm_1 (LSTM) (None, 128) 92672 \n", "_________________________________________________________________\n", "dense_1 (Dense) (None, 52) 6708 \n", "=================================================================\n", "Total params: 99,380\n", "Trainable params: 99,380\n", "Non-trainable params: 0\n", "_________________________________________________________________\n", "Epoch 1/100\n", "19500/19500 [==============================] - 2s 105us/step - loss: 
3.0127\n", "Epoch 2/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 2.2989\n", "Epoch 3/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 2.1260\n", "Epoch 4/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 2.0619\n", "Epoch 5/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 2.0289\n", "Epoch 6/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 2.0117\n", "Epoch 7/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.9948\n", "Epoch 8/100\n", "19500/19500 [==============================] - 0s 9us/step - loss: 1.9815\n", "Epoch 9/100\n", "19500/19500 [==============================] - 0s 9us/step - loss: 1.9688\n", "Epoch 10/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.9593\n", "Epoch 11/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.9502\n", "Epoch 12/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.9409\n", "Epoch 13/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.9341\n", "Epoch 14/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.9243\n", "Epoch 15/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.9186\n", "Epoch 16/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.9122\n", "Epoch 17/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.9065\n", "Epoch 18/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.8993\n", "Epoch 19/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8987\n", "Epoch 20/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8901\n", "Epoch 21/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.8880\n", "Epoch 22/100\n", "19500/19500 
[==============================] - 0s 7us/step - loss: 1.8842\n", "Epoch 23/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.8794\n", "Epoch 24/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8791\n", "Epoch 25/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.8728\n", "Epoch 26/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8672\n", "Epoch 27/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8680\n", "Epoch 28/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8627\n", "Epoch 29/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.8590\n", "Epoch 30/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8584\n", "Epoch 31/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8529\n", "Epoch 32/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.8527\n", "Epoch 33/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8516\n", "Epoch 34/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.8460\n", "Epoch 35/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.8450\n", "Epoch 36/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8438\n", "Epoch 37/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8414\n", "Epoch 38/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8409\n", "Epoch 39/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8400\n", "Epoch 40/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8382\n", "Epoch 41/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8348\n", "Epoch 42/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 
1.8321\n", "Epoch 43/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.8323\n", "Epoch 44/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.8294\n", "Epoch 45/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.8307\n", "Epoch 46/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.8257\n", "Epoch 47/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8269\n", "Epoch 48/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8253\n", "Epoch 49/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8243\n", "Epoch 50/100\n", "19500/19500 [==============================] - 0s 9us/step - loss: 1.8226\n", "Epoch 51/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8244\n", "Epoch 52/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8210\n", "Epoch 53/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8198\n", "Epoch 54/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8186\n", "Epoch 55/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.8189\n", "Epoch 56/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8169\n", "Epoch 57/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8177\n", "Epoch 58/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.8151\n", "Epoch 59/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8130\n", "Epoch 60/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8129\n", "Epoch 61/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8131\n", "Epoch 62/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8118\n", "Epoch 63/100\n", "19500/19500 
[==============================] - 0s 7us/step - loss: 1.8120\n", "Epoch 64/100\n", "19500/19500 [==============================] - 0s 9us/step - loss: 1.8132\n", "Epoch 65/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.8110\n", "Epoch 66/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8089\n", "Epoch 67/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.8094\n", "Epoch 68/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.8080\n", "Epoch 69/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8089\n", "Epoch 70/100\n", "19500/19500 [==============================] - 0s 9us/step - loss: 1.8094\n", "Epoch 71/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8095\n", "Epoch 72/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8076\n", "Epoch 73/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8052\n", "Epoch 74/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.8069\n", "Epoch 75/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8050\n", "Epoch 76/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8043\n", "Epoch 77/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8035\n", "Epoch 78/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.8027\n", "Epoch 79/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8027\n", "Epoch 80/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8040\n", "Epoch 81/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.8027\n", "Epoch 82/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8024\n", "Epoch 83/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 
1.8016\n", "Epoch 84/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.8004\n", "Epoch 85/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.7987\n", "Epoch 86/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8002\n", "Epoch 87/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.8012\n", "Epoch 88/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.7994\n", "Epoch 89/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.8012\n", "Epoch 90/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.7982\n", "Epoch 91/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.7992\n", "Epoch 92/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.7991\n", "Epoch 93/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.7975\n", "Epoch 94/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.7972\n", "Epoch 95/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.7991\n", "Epoch 96/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.7972\n", "Epoch 97/100\n", "19500/19500 [==============================] - 0s 8us/step - loss: 1.7973\n", "Epoch 98/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.7965\n", "Epoch 99/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.7970\n", "Epoch 100/100\n", "19500/19500 [==============================] - 0s 7us/step - loss: 1.7955\n" ], "name": "stdout" }, { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": { "tags": [] }, "execution_count": 6 } ] }, { "cell_type": "code", "metadata": { "colab_type": "code", "id": "pKFXJqYIOwSf", "outputId": "df92996d-1a0e-4114-9487-9e38bf89157a", "colab": { "base_uri": "https://localhost:8080/", 
"height": 34 } }, "source": [ "# Arguments: seed prefix, total length of the returned name, and the vocabulary\n", "# (chars, char_indices, len(chars)) built by preprocess_data. Characters are\n", "# sampled at random, so re-running a cell can give a different name.\n", "generate_names(\"Ab\", 7, chars, char_indices, len(chars))" ], "execution_count": 11, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "'Abeerah'" ] }, "metadata": { "tags": [] }, "execution_count": 11 } ] }, { "cell_type": "code", "metadata": { "id": "jAh4eO25qjcs", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "outputId": "b8e6fe4f-04fd-4002-d8ec-076e612758d7" }, "source": [ "generate_names(\"An\", 7, chars, char_indices, len(chars))" ], "execution_count": 17, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "'Anaheen'" ] }, "metadata": { "tags": [] }, "execution_count": 17 } ] }, { "cell_type": "code", "metadata": { "id": "Yp-IT69XrUvU", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "outputId": "22bc5b8f-a665-4f61-fe2f-b337e6c91af4" }, "source": [ "generate_names(\"Mi\", 7, chars, char_indices, len(chars))" ], "execution_count": 18, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "'Mishadi'" ] }, "metadata": { "tags": [] }, "execution_count": 18 } ] }, { "cell_type": "code", "metadata": { "id": "bTAkewJtrYIu", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "outputId": "eb55284e-c90e-4c04-96a6-c2c67844ad59" }, "source": [ "generate_names(\"Ma\", 5, chars, char_indices, len(chars))" ], "execution_count": 31, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "'Mahid'" ] }, "metadata": { "tags": [] }, "execution_count": 31 } ] }, { "cell_type": "code", "metadata": { "id": "II-QYNrGrjqB", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "outputId": "5f567fa4-6165-4ac3-95b7-605a86f2cddf" }, "source": [ "# This seed is a single character, shorter than the max_length (2) context window.\n", "generate_names(\"Z\", 6, chars, char_indices, len(chars))" ], "execution_count": 34, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "'Zairam'" ] }, 
"metadata": { "tags": [] }, "execution_count": 34 } ] }, { "cell_type": "code", "metadata": { "id": "CV5BpbumrpSJ", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "outputId": "1beb1a06-c1b7-404e-b29a-42e4ffc7005c" }, "source": [ "generate_names(\"Za\", 8, chars, char_indices, len(chars))" ], "execution_count": 38, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "'Zareenah'" ] }, "metadata": { "tags": [] }, "execution_count": 38 } ] } ] }