feat: DBの接続関連を修正。UpperCaseにするのをやめた

This commit is contained in:
shimoda.m@nds-tyo.co.jp 2023-03-30 15:26:23 +09:00
parent 471ccec9bd
commit f323dcedd0
9 changed files with 144 additions and 57 deletions

View File

@ -1,71 +1,147 @@
from sqlalchemy import Engine, create_engine, text
from sqlalchemy import (Connection, CursorResult, Engine, QueuePool,
create_engine, text)
from sqlalchemy.engine.create import create_engine
from sqlalchemy.engine.url import URL
from src.error.exceptions import DBException
from src.system_var import environment
class Database:
"""データベース操作クラス"""
__connection: Connection = None
__engine: Engine = None
__host: str = None
__port: str = None
__user: str = None
__username: str = None
__password: str = None
__schema: str = None
__connection_string:str = None
def __init__(self, host: str, port: int, user: str, password: str, schema: str) -> None:
def __init__(self, username: str, password: str, host: str, port: int, schema: str) -> None:
"""このクラスの新たなインスタンスを初期化します
Args:
username (str): DBユーザー名
password (str): DBパスワード
host (str): DBホスト名
port (int): DBポート
schema (str): DBスキーマ名
"""
self.__username = username
self.__password = password
self.__host = host
self.__port = int(port)
self.__user = user
self.__password = password
self.__schema = schema
self.__connection_string = URL.create(
drivername='mysql+pymysql',
username=self.__user,
username=self.__username,
password=self.__password,
host=self.__host,
port=self.__port,
database=self.__schema,
query={"charset": "utf8mb4"}
)
self.__engine = create_engine(
self.__connection_string,
pool_timeout=5,
poolclass=QueuePool,
isolation_level="AUTOCOMMIT"
)
@classmethod
def get_instance(cls):
"""インスタンスを取得します
Returns:
Database: DB操作クラスインスタンス
"""
return cls(
username=environment.DB_USERNAME,
password=environment.DB_PASSWORD,
host=environment.DB_HOST,
port=environment.DB_PORT,
schema=environment.DB_SCHEMA
)
@property
def connection(self):
return self.__engine
"""
DBの接続を返します
"""
return self.__connection
def engine_init(self):
engine = create_engine(
self.__connection_string,
pool_timeout=5
)
self.__engine = engine
def connect(self):
"""
DBに接続します接続に失敗した場合リトライします
Raises:
DBException: 接続失敗
"""
self.__connection = self.__engine.connect()
def execute_query(self, select_query: str, parameters=None) -> list[dict]:
if self.__engine is None:
raise DBException('データベースが初期化されていません')
def execute_select(self, select_query: str, parameters=None) -> list[dict]:
"""SELECTクエリを実行します。
with self.__engine.begin() as trx:
try:
result = trx.execute(text(select_query), parameters=parameters)
except Exception as e:
trx.rollback()
raise DBException(e)
return result.mappings().all()
Args:
select_query (str): SELECT文
parameters (dict, optional): クエリのプレースホルダーに埋め込む変数の辞書. Defaults to None.
def execute(self, query: str, parameters=None) -> None:
if self.__engine is None:
raise DBException('データベースが初期化されていません')
Raises:
DBException: DBエラー
with self.__engine.begin() as trx:
try:
trx.execute(text(query), parameters=parameters)
except Exception as e:
trx.rollback()
raise DBException(e)
Returns:
list[dict]: カラム名: 値の辞書リスト
"""
if self.__connection is None:
raise DBException('DBに接続していません')
try:
result = self.__connection.execute(text(select_query), parameters=parameters)
except Exception as e:
raise DBException(e)
result_rows = result.mappings().all()
return result_rows
def close(self):
if self.__engine is not None:
self.__engine.dispose(close=True)
self.__engine = None
def execute(self, query: str, parameters=None) -> CursorResult:
"""SQLクエリを実行します。
Args:
query (str): SQL文
parameters (dict, optional): クエリのプレースホルダーに埋め込む変数の辞書. Defaults to None.
Raises:
DBException: DBエラー
Returns:
CursorResult: 取得結果
"""
if self.__connection is None:
raise DBException('DBに接続していません')
try:
result = self.__connection.execute(text(query), parameters=parameters)
except Exception as e:
raise DBException(e)
return result
def begin(self):
"""トランザクションを開始します。"""
if not self.__connection.in_transaction():
self.__connection.begin()
def commit(self):
"""トランザクションをコミットします"""
if self.__connection.in_transaction():
self.__connection.commit()
def rollback(self):
"""トランザクションをロールバックします"""
if self.__connection.in_transaction():
self.__connection.rollback()
def disconnect(self):
"""DB接続を切断します。"""
if self.__connection is not None:
self.__connection.close()
self.__connection = None

View File

@ -1,22 +1,14 @@
from fastapi import FastAPI
from src.db.database import Database
from src.system_var import environment
def init_db(app: FastAPI) -> None:
# DB接続モジュールを初期化
database = Database(
host=environment.DB_HOST,
port=environment.DB_PORT,
user=environment.DB_USERNAME,
password=environment.DB_PASSWORD,
schema=environment.DB_SCHEMA
)
database.engine_init()
database = Database.get_instance()
# FastAPI App内で使える変数として追加
app.state._db = database
def close_db(app: FastAPI) -> None:
app.state._db.close()
app.state._db = None

View File

@ -36,6 +36,7 @@ class BaseRepository(metaclass=ABCMeta):
text(query),
con=self._database.connection,
params=params)
df = pd.DataFrame(
sql_query,
index=None

View File

@ -38,12 +38,13 @@ class BioSalesViewRepository(BaseRepository):
def fetch_many(self, parameter: BioModel) -> list[BioSalesViewModel]:
try:
self._database.connect()
where_clause = self.__build_condition(parameter)
# error_log(date("Y/m/d H:i:s") . " [INFO] DB Return=" . $result . "\r\n", 3, "$execLog");
# error_log(date("Y/m/d H:i:s") . " [INFO] DB参照実行" . "\r\n", 3, "$execLog");
query = self.FETCH_SQL.format(where_clause=where_clause)
# error_log(date("Y/m/d H:i:s") . " [INFO] SQL: " . $query . "\r\n", 3, "$execLog");
result = self._database.execute_query(query, parameter.dict())
result = self._database.execute_select(query, parameter.dict())
models = [BioSalesViewModel(**r) for r in result]
# error_log(date("Y/m/d H:i:s") . " [INFO] count=" . $count . "\r\n", 3, "$execLog");
return models
@ -51,15 +52,18 @@ class BioSalesViewRepository(BaseRepository):
# TODO: ファイルへの書き出しはloggerでやる
print(f"[ERROR] DB Error : Exception={e.args}")
raise e
finally:
self._database.disconnect()
def fetch_as_data_frame(self, parameter: BioModel):
try:
self._database.connect()
where_clause = self.__build_condition(parameter)
# error_log(date("Y/m/d H:i:s") . " [INFO] DB Return=" . $result . "\r\n", 3, "$execLog");
# error_log(date("Y/m/d H:i:s") . " [INFO] DB参照実行" . "\r\n", 3, "$execLog");
query = self.FETCH_SQL.format(where_clause=where_clause)
# error_log(date("Y/m/d H:i:s") . " [INFO] SQL: " . $query . "\r\n", 3, "$execLog");
# result = self._database.execute_query(query, parameter.dict())
# models = [BioSalesViewModel(**r) for r in result]
# error_log(date("Y/m/d H:i:s") . " [INFO] count=" . $count . "\r\n", 3, "$execLog");
df = self._to_data_frame(query, parameter)
@ -68,6 +72,8 @@ class BioSalesViewRepository(BaseRepository):
# TODO: ファイルへの書き出しはloggerでやる
print(f"[ERROR] DB Error : Exception={e.args}")
raise e
finally:
self._database.disconnect()
def __build_condition(self, parameter: BioModel):
where_clauses: list[SQLCondition] = []

View File

@ -8,11 +8,14 @@ class HdkeTblRepository(BaseRepository):
def fetch_all(self) -> list[HdkeTblModel]:
try:
self._database.connect()
query = self.FETCH_SQL
result = self._database.execute_query(query)
result = self._database.execute_select(query)
models = [HdkeTblModel(**r) for r in result]
return models
except Exception as e:
# TODO: ファイルへの書き出しはloggerでやる
print(f"[ERROR] DB Error : Exception={e.args}")
raise e
finally:
self._database.disconnect()

View File

@ -26,7 +26,8 @@ class PharmacyProductMasterRepository(BaseRepository):
def fetch_all(self) -> list[PharmacyProductMasterModel]:
try:
result = self._database.execute_query(self.FETCH_SQL)
self._database.connect()
result = self._database.execute_select(self.FETCH_SQL)
models = [PharmacyProductMasterModel(**r) for r in result]
return models
except Exception as e:
@ -34,3 +35,5 @@ class PharmacyProductMasterRepository(BaseRepository):
print(f"[ERROR] getOroshiData DB Error. ")
print(f"[ERROR] ErrorMessage: {e.args}")
raise e
finally:
self._database.disconnect()

View File

@ -10,13 +10,14 @@ class UserMasterRepository(BaseRepository):
FROM
src05.user_mst
WHERE
LOWER(user_id) = LOWER(:user_id)\
user_id = :user_id\
"""
def fetch_one(self, parameter: dict) -> UserMasterModel:
try:
self._database.connect()
query = self.FETCH_SQL
result = self._database.execute_query(query, parameter)
result = self._database.execute_select(query, parameter)
models = [UserMasterModel(**r) for r in result]
if len(models) == 0:
return None
@ -25,3 +26,5 @@ class UserMasterRepository(BaseRepository):
# TODO: ファイルへの書き出しはloggerでやる
print(f"[ERROR] DB Error : Exception={e.args}")
raise e
finally:
self._database.disconnect()

View File

@ -24,7 +24,8 @@ class WholesalerMasterRepository(BaseRepository):
def fetch_all(self) -> list[WholesalerMasterModel]:
try:
result = self._database.execute_query(self.FETCH_SQL)
self._database.connect()
result = self._database.execute_select(self.FETCH_SQL)
result_data = [res for res in result]
models = [WholesalerMasterModel(**r) for r in result_data]
return models
@ -33,3 +34,5 @@ class WholesalerMasterRepository(BaseRepository):
print(f"[ERROR] getOroshiData DB Error. ")
print(f"[ERROR] ErrorMessage: {e.args}")
raise e
finally:
self._database.disconnect()

View File

@ -72,7 +72,7 @@ class BioViewService(BaseService):
def write_excel_file(self, data_frame: pd.DataFrame, user_id: str, timestamp: datetime):
# Excelに書き込み
output_file_path = path.join(constants.BIO_TEMPORARY_FILE_DIR_PATH, f'Result_{user_id.upper()}_{timestamp:%Y%m%d%H%M%S%f}.xlsx')
output_file_path = path.join(constants.BIO_TEMPORARY_FILE_DIR_PATH, f'Result_{user_id}_{timestamp:%Y%m%d%H%M%S%f}.xlsx')
# テンプレートファイルをコピーして出力ファイルの枠だけを作る
shutil.copyfile(
@ -92,7 +92,7 @@ class BioViewService(BaseService):
def write_csv_file(self, data_frame: pd.DataFrame, user_id: str, header: list[str], timestamp: datetime):
# csvに書き込み
output_file_path = path.join(constants.BIO_TEMPORARY_FILE_DIR_PATH, f'Result_{user_id.upper()}_{timestamp:%Y%m%d%H%M%S%f}.csv')
output_file_path = path.join(constants.BIO_TEMPORARY_FILE_DIR_PATH, f'Result_{user_id}_{timestamp:%Y%m%d%H%M%S%f}.csv')
# 横長のDataFrameとするため、ヘッダーの加工処理
header_data = {}
for df_column, header_column in zip(data_frame.columns, header):
@ -115,5 +115,5 @@ class BioViewService(BaseService):
bucket_name = environment.BIO_ACCESS_LOG_BUCKET
# TODO: フォルダを変える
file_key = f'bio/{path.basename(local_file_path)}'
download_filename = f'{user_id.upper()}_生物由来卸販売データ.{kind}'
download_filename = f'{user_id}_生物由来卸販売データ.{kind}'
return self.s3_client.generate_presigned_url(bucket_name, file_key, download_filename)