PythonでSSH経由し外部のMySQLデータベースを操作してみる

Table of Contents

Table of Contents

ローカルで外部のデータベースを操作したい時に、いつもSequel proなどのツールを使って、SQL文を実行するようにしていますが、動的にSQL文を作成したいや結果確認も一気にやりたい時になると、少し不便かなーと感じています。

このような場合、Pythonを使ってみたらどうかなと思い、かつ最近もJupyterを試してみたいので、今回はJupyterを使って、SSH経由でMySQLデータベースを操作してみまして、その方法をメモしました。


やりたいこと

  • 踏み台(bastion)経由してMySQL DBに接続する
  • 更新対象のレコードを抽出し、更新SQL文を動的に作成する
  • データ更新を行い、進捗を表示する
  • 確認しやすくため、更新されたレコードの更新前の情報を保持する

Jupyter

今回は、Jupyterを使う前提でコードを作成します。



必要なmoduleを用意

まず、必要なmoduleをインポートします。

from sshtunnel import SSHTunnelForwarder
import pymysql as db
import pandas as pd
import datetime

SSHで踏み台サーバーとデータベースサーバーにアクセスするため、SSHTunnelForwarderをインポートします。

MySQLを操作するため、pymysqlをインポートします。

DataFrame形式でデータを操作したいため、pandasをインポートします。

操作時間を計算するため、datetimeをインポートします。


データベースサーバーの接続と検索関数の作成

踏み台経由でデータベースサーバーに接続して、検索SQLを実行する

# ssh
sshOptions = {
    "bastion": {
        "host": "host",
        "ssh_username": "user",
        "ssh_private_key": "~/develop/bastion.pem",
        "ssh_password": "password"
    }
}
ssh = sshOptions["bastion"]

# database
dbOptions = {
    "mysql": {
        "localhost": "host",
        "user": "user",
        "password": "password",
        "database": "db",
        "port": 3306
    }
}
dbConfig = dbOptions["mysql"]

def query(q):
    with SSHTunnelForwarder(
    (ssh["host"], 22),
    ssh_username = ssh["ssh_username"],
    ssh_password = ssh["ssh_password"],
    ssh_private_key = ssh["ssh_private_key"],
    remote_bind_address=(dbConfig["localhost"], dbConfig["port"])
    ) as server:
        conn = db.connect(host = '127.0.0.1',
                          port = server.local_bind_port,
                          user = dbConfig["user"],
                          passwd = dbConfig["password"],
                          db = dbConfig["database"],
                          charset = 'utf8',
                          cursorclass = db.cursors.DictCursor)
        df = pd.read_sql_query(q, conn)
        conn.close()
        
        return df

まず、sshOptions、dbOptionsには、踏み台サーバーのIP、ポート、SSH関する設定とデーターベースサーバーのIP、ポートを保持します。

SSHで踏み台サーバーに接続する部分は、SSHTunnelForwarderで処理します。 remote_bind_addressにデーターベースのIPとポートを設定します。

(踏み台を使わないままで直接データーベースに接続する場合、踏み台の部分をデータベースサーバーの情報に書き換えて、remote_bind_addressの1番目の引数を127.0.0.1にすればOK)

あとは、データを検索する関数query()を作成します。 SQL文を引数として、関数を実行すればDataFrame形式の結果を返してくれます。


進捗表示関数の作成

大量なデータ更新を実行している時に、進捗を確認したいですので、下記の関数を作成します。

def log_progress(sequence, every=None, size=None, name='Items'):
    from ipywidgets import IntProgress, HTML, VBox
    from IPython.display import display

    is_iterator = False
    if size is None:
        try:
            size = len(sequence)
        except TypeError:
            is_iterator = True
    if size is not None:
        if every is None:
            if size <= 200:
                every = 1
            else:
                every = int(size / 200)     # every 0.5%
    else:
        assert every is not None, 'sequence is iterator, set every'

    if is_iterator:
        progress = IntProgress(min=0, max=1, value=1)
        progress.bar_style = 'info'
    else:
        progress = IntProgress(min=0, max=size, value=0)
    label = HTML()
    box = VBox(children=[label, progress])
    display(box)

    index = 0
    try:
        for index, record in enumerate(sequence, 1):
            if index == 1 or index % every == 0:
                if is_iterator:
                    label.value = '{name}: {index} / ?'.format(
                        name=name,
                        index=index
                    )
                else:
                    progress.value = index
                    label.value = u'{name}: {index} / {size}'.format(
                        name=name,
                        index=index,
                        size=size
                    )
            yield record
    except:
        progress.bar_style = 'danger'
        raise
    else:
        progress.bar_style = 'success'
        progress.value = index
        label.value = "{name}: {index}".format(
            name=name,
            index=str(index or '?')
        )

使う例

for url in log_progress(urls[:10], every=1):
    # ループで実行したい処理
    # for example: print(url)

更新処理関数の作成

ここは、複数データを更新する想定です。 接続の処理はquery関数と同じです。

def update(targets):
    with SSHTunnelForwarder(
    (ssh["host"], 22),
    ssh_username = ssh["ssh_username"],
    ssh_password = 'ssh["ssh_password"]',
    ssh_private_key = ssh["ssh_private_key"],
    remote_bind_address=(dbConfig["localhost"], dbConfig["port"])
    ) as server:
        conn = db.connect(host = '127.0.0.1',
                          port = server.local_bind_port,
                          user = dbConfig["user"],
                          passwd = dbConfig["password"],
                          db = dbConfig["database"],
                          charset = 'utf8',
                          cursorclass = db.cursors.DictCursor)
        cursor = conn.cursor()
        record = []
        for index in log_progress(targets.index, every=1):
            data = targets.loc[index]
            sql2 = "update hoge set value = 0, updated_date=NOW() where id = %d"
            cursor.execute(sql2 % (data['id']))
            if (cursor.rowcount == 1):
                record.append('%d, %d, %s' % (data['id'] , data['value'], data['updated_date']))
            conn.commit()
        
        conn.close()
        return record

更新対象を関数の引数として渡し、ループで更新のSQL文を実行します。 ここは単にレコードのvalueとupdated_atの更新になります。

更新成功した場合、変更前のデータを保存する。 更新実行中に、進捗を表示したいため、log_progress関数を使用します。


対象レコードを抽出して更新を実行する

関数作成を完了しましたので、実際に例を実行してみます。

sql = "select id, value, updated_date \
from hoge \
where \
and value < 0 ;" 
df = query(sql)

startTime = datetime.datetime.now()
print(startTime, '-----', '', sep='\n')

re = update(df)

endTime = datetime.datetime.now()
print('------', endTime, '', sep='\n')
print('Time:', endTime - startTime)
print('the data before changing', ';\n'.join(re), sep='\n')

まず、先ほど作成したquery関数で更新したいレコードを抽出します。 ここは、value<0のレコードを抽出します。

その後、抽出したレコードをupdate関数に渡して、更新を実行します。

最後に、実行時間、変更されたデータの変更前の情報を表示します。