from sqlalchemy import (CursorResult, Engine, NullPool, create_engine, event, exc, text) from sqlalchemy.engine.url import URL from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.pool import Pool from src.db.db_row_generator import DBRowGenerator from src.error.exceptions import DBException from src.logging.get_logger import get_logger from src.system_var import environment logger = get_logger('DB接続') @event.listens_for(Pool, 'checkout') def ping_connection(dbapi_connection, connection_record, connection_proxy): """コネクションが切れた場合、再接続""" cursor = dbapi_connection.cursor() try: cursor.execute("SELECT 1") except Exception as e: logger.info(f'DB接続に失敗したため、リトライします: {e}') # raise DisconnectionError - pool will try # connecting again up to three times before raising. raise exc.DisconnectionError() cursor.close() class DatabaseSession: """データベースセッション管理クラス""" __engine: Engine = None __host: str = None __port: str = None __username: str = None __password: str = None __schema: str = None __connection_string: str = None __instance = None def __init__(self, username: str, password: str, host: str, port: int, schema: str) -> None: """このクラスの新たなインスタンスを初期化します Args: username (str): DBユーザー名 password (str): DBパスワード host (str): DBホスト名 port (int): DBポート schema (str): DBスキーマ名 """ self.__username = username self.__password = password self.__host = host self.__port = int(port) self.__schema = schema self.__connection_string = URL.create( drivername='mysql+pymysql', username=self.__username, password=self.__password, host=self.__host, port=self.__port, database=self.__schema, query={"charset": "utf8mb4"} ) self.__engine = create_engine( self.__connection_string, poolclass=NullPool ) @classmethod def get_instance(cls): """インスタンスを取得します Returns: DatabaseSession: DB操作クラスインスタンス """ if cls.__instance is None: cls.__instance = cls( username=environment.DB_USERNAME, password=environment.DB_PASSWORD, host=environment.DB_HOST, port=environment.DB_PORT, schema=environment.DB_SCHEMA ) return cls.__instance def create_session(self) -> Session: """ DBの接続を返します。 Returns: Session: sqlalchemy.orm.Session """ try: return sessionmaker(autoflush=False, bind=self.__engine)() except Exception as e: raise DBException(e) class DatabaseClient: __session: Session = None def __init__(self, session: Session) -> None: self.__session = session @property def session(self) -> Session: """ DBのセッションを返します。 """ if self.__session is None: raise DBException('DBに接続していません') return self.__session def execute_select(self, select_query: str, parameters=None) -> list[dict]: """SELECTクエリを実行します。 Args: select_query (str): SELECT文 parameters (dict, optional): クエリのプレースホルダーに埋め込む変数の辞書. Defaults to None. Raises: DBException: DBエラー Returns: DBRowGenerator: カラム名: 値の辞書リストを返すジェネレータオブジェクト """ result = None try: # トランザクションが開始している場合は、トランザクションを引き継ぐ if self.session.in_transaction(): result = self.session.execute(text(select_query), parameters) else: # トランザクションが明示的に開始していない場合は、クエリ単位でトランザクションをbegin-commitする。 result = self.__execute_with_transaction(select_query, parameters) except Exception as e: raise DBException(e) result_rows = DBRowGenerator(result.mappings()) return result_rows def execute(self, query: str, parameters=None) -> CursorResult: """SQLクエリを実行します。 Args: query (str): SQL文 parameters (dict, optional): クエリのプレースホルダーに埋め込む変数の辞書. Defaults to None. Raises: DBException: DBエラー Returns: CursorResult: 取得結果 """ result = None try: # トランザクションが開始している場合は、トランザクションを引き継ぐ if self.session.in_transaction(): result = self.session.execute(text(query), parameters) else: # トランザクションが明示的に開始していない場合は、クエリ単位でトランザクションをbegin-commitする。 result = self.__execute_with_transaction(query, parameters) except Exception as e: raise DBException(e) return result def begin(self): """トランザクションを開始します。""" if not self.session.in_transaction(): self.session.begin() def commit(self): """トランザクションをコミットします""" if self.session.in_transaction(): self.session.commit() def rollback(self): """トランザクションをロールバックします""" if self.session.in_transaction(): self.session.rollback() def disconnect(self): """DB接続を切断します。""" self.session.close() self.__session = None def to_jst(self): self.execute('SET SESSION time_zone = "Asia/Tokyo"') def __execute_with_transaction(self, query: str, parameters: dict): # トランザクションを開始してクエリを実行する with self.session.begin(): try: result = self.session.execute(text(query), parameters) except Exception as e: self.session.rollback() raise e # ここでコミットされる return result