fix: 同時接続でエラーになる問題の解消策として、DBセッションの取り回し方法を修正した(まずは共通部分の未修正)

This commit is contained in:
shimoda.m@nds-tyo.co.jp 2023-08-31 11:30:58 +09:00
parent bed91795f1
commit 266e7cd28b
10 changed files with 132 additions and 81 deletions

View File

View File

@ -0,0 +1,12 @@
"""FastAPIサーバーの起動イベントのラッパー"""
from typing import Callable
from src.db.tasks import init_db
def create_start_app_handler() -> Callable:
def start_app() -> None:
init_db()
return start_app

View File

@ -1,6 +1,7 @@
from sqlalchemy import (Connection, CursorResult, Engine, NullPool,
create_engine, event, exc, text)
from sqlalchemy import (CursorResult, Engine, NullPool, create_engine, event,
exc, text)
from sqlalchemy.engine.url import URL
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import Pool
from src.error.exceptions import DBException
@ -24,9 +25,8 @@ def ping_connection(dbapi_connection, connection_record, connection_proxy):
cursor.close()
class Database:
"""データベース操作クラス"""
__connection: Connection = None
class DatabaseSession:
"""データベースセッション管理クラス"""
__engine: Engine = None
__host: str = None
__port: str = None
@ -34,6 +34,7 @@ class Database:
__password: str = None
__schema: str = None
__connection_string: str = None
__instance = None
def __init__(self, username: str, password: str, host: str, port: int, schema: str) -> None:
"""このクラスの新たなインスタンスを初期化します
@ -73,32 +74,49 @@ class Database:
Returns:
Database: DB操作クラスインスタンス
"""
return cls(
username=environment.DB_USERNAME,
password=environment.DB_PASSWORD,
host=environment.DB_HOST,
port=environment.DB_PORT,
schema=environment.DB_SCHEMA
)
if cls.__instance is None:
cls.__instance = cls(
username=environment.DB_USERNAME,
password=environment.DB_PASSWORD,
host=environment.DB_HOST,
port=environment.DB_PORT,
schema=environment.DB_SCHEMA
)
@property
def connection(self):
return cls.__instance
def create_session(self):
"""
DBの接続を返します
"""
return self.__connection
def connect(self):
"""
DBに接続します接続に失敗した場合リトライします
Raises:
DBException: 接続失敗
"""
try:
self.__connection = self.__engine.connect()
return sessionmaker(autoflush=False, bind=self.__engine)()
except Exception as e:
raise DBException(e)
class DatabaseClient:
__session: Session = None
def __init__(self, session: Session) -> None:
self.__session = session
@property
def session(self) -> Session:
"""
DBのセッションを返します
"""
if self.__session is None:
raise DBException('DBに接続していません')
return self.__session
def connect(self):
...
def disconnect(self):
...
def execute_select(self, select_query: str, parameters=None) -> list[dict]:
"""SELECTクエリを実行します。
@ -112,14 +130,11 @@ class Database:
Returns:
list[dict]: カラム名: 値の辞書リスト
"""
if self.__connection is None:
raise DBException('DBに接続していません')
result = None
try:
# トランザクションが開始している場合は、トランザクションを引き継ぐ
if self.__connection.in_transaction():
result = self.__connection.execute(text(select_query), parameters)
if self.session.in_transaction():
result = self.session.execute(text(select_query), parameters)
else:
# トランザクションが明示的に開始していない場合は、クエリ単位でトランザクションをbegin-commitする。
result = self.__execute_with_transaction(select_query, parameters)
@ -142,14 +157,11 @@ class Database:
Returns:
CursorResult: 取得結果
"""
if self.__connection is None:
raise DBException('DBに接続していません')
result = None
try:
# トランザクションが開始している場合は、トランザクションを引き継ぐ
if self.__connection.in_transaction():
result = self.__connection.execute(text(query), parameters)
if self.session.in_transaction():
result = self.session.execute(text(query), parameters)
else:
# トランザクションが明示的に開始していない場合は、クエリ単位でトランザクションをbegin-commitする。
result = self.__execute_with_transaction(query, parameters)
@ -160,35 +172,35 @@ class Database:
def begin(self):
"""トランザクションを開始します。"""
if not self.__connection.in_transaction():
self.__connection.begin()
if not self.session.in_transaction():
self.session.begin()
def commit(self):
"""トランザクションをコミットします"""
if self.__connection.in_transaction():
self.__connection.commit()
if self.session.in_transaction():
self.session.commit()
def rollback(self):
"""トランザクションをロールバックします"""
if self.__connection.in_transaction():
self.__connection.rollback()
if self.session.in_transaction():
self.session.rollback()
def disconnect(self):
"""DB接続を切断します。"""
if self.__connection is not None:
self.__connection.close()
self.__connection = None
# def disconnect(self):
# """DB接続を切断します。"""
# if self.__session is not None:
# self.__session.close()
# self.session = None
def to_jst(self):
self.execute('SET time_zone = "+9:00"')
def __execute_with_transaction(self, query: str, parameters: dict):
# トランザクションを開始してクエリを実行する
with self.__connection.begin():
with self.session.begin():
try:
result = self.__connection.execute(text(query), parameters=parameters)
result = self.session.execute(text(query), parameters)
except Exception as e:
self.__connection.rollback()
self.session.rollback()
raise e
# ここでコミットされる
return result

View File

@ -0,0 +1,7 @@
from src.db.database import DatabaseSession
def init_db() -> None:
# DB接続モジュールを初期化
# 以降、get_instance()で唯一のインスタンスを取得する
DatabaseSession.get_instance()

View File

@ -1,13 +1,15 @@
from starlette.requests import Request
from src.db.database import Database
from src.db.database import DatabaseClient, DatabaseSession
def get_database(request: Request) -> Database:
# medaca_routerでDB接続エンジンが初期化される
db = getattr(request.app.state, '_db', None)
# uvicornのワーカーが起動したタイミングでは、dbがセットされていないので、ここでセットここでセットする
if db is None:
db = Database.get_instance()
setattr(request.app.state, '_db', db)
return db
def get_database(request: Request) -> DatabaseClient:
try:
database_session = DatabaseSession.get_instance()
session = database_session.create_session()
database = DatabaseClient(session)
yield database
finally:
database.disconnect()
# FIXME; リポジトリを直し終えたら消す
session.close()

View File

@ -2,13 +2,13 @@ from typing import Callable, Type
from fastapi import Depends
from src.db.database import Database
from src.db.database import DatabaseClient
from src.depends.database import get_database
from src.services.base_service import BaseService
def get_service(Service_type: Type[BaseService]) -> Callable:
def get_service(db: Database = Depends(get_database)) -> Type[BaseService]:
def get_service(db: DatabaseClient = Depends(get_database)) -> Type[BaseService]:
repositories = {key: repository(db) for key, repository in Service_type.REPOSITORIES.items()}
clients = {key: client() for key, client in Service_type.CLIENTS.items()}
return Service_type(repositories=repositories, clients=clients)

View File

@ -27,7 +27,7 @@ def get_logger(log_name: str) -> logging.Logger:
logger.addHandler(handler)
formatter = logging.Formatter(
'%(name)s\t[%(levelname)s]\t%(asctime)s\t%(message)s',
'%(name)s[%(process)d][%(thread)d]\t[%(levelname)s]\t%(asctime)s\t%(message)s',
'%Y-%m-%d %H:%M:%S'
)

View File

@ -1,14 +1,18 @@
import os.path as path
from fastapi import FastAPI
from fastapi import Depends, FastAPI
from fastapi.staticfiles import StaticFiles
from starlette import status
import src.static as static
from src.controller import (bio, bio_download, healthcheck, login, logout,
master_mainte, menu, root, ultmarc)
from src.core import task
from src.depends.services import get_service
from src.error.exception_handler import http_exception_handler
from src.error.exceptions import UnexpectedException
from src.logging.get_logger import get_logger
from src.services.batch_status_service import BatchStatusService
app = FastAPI()
@ -40,3 +44,21 @@ app.add_exception_handler(status.HTTP_403_FORBIDDEN, http_exception_handler)
# サーバーエラーが発生した場合のハンドラー。HTTPExceptionではハンドリングできないため、個別に設定
app.add_exception_handler(UnexpectedException, http_exception_handler)
# サーバー起動時のイベント
app.add_event_handler('startup', task.create_start_app_handler)
# logger = get_logger(__name__)
# @app.get('/sample/')
# def sample(service: BatchStatusService = Depends(get_service(BatchStatusService))):
# # import os
# # import threading
# logger.info('START')
# res = service.hdke_table_record
# logger.info(res)
# logger.info('END')
# return res
# # return f'{os.getpid()}, {threading.get_ident()}'

View File

@ -3,15 +3,15 @@ from abc import ABCMeta
import pandas as pd
from sqlalchemy import text
from src.db.database import Database
from src.db.database import DatabaseClient
from src.model.db.base_db_model import BaseDBModel
class BaseRepository(metaclass=ABCMeta):
_database: Database
_database: DatabaseClient
def __init__(self, db: Database) -> None:
def __init__(self, db: DatabaseClient) -> None:
self._database = db
def fetch_all(self) -> list[BaseDBModel]:
@ -36,7 +36,7 @@ class BaseRepository(metaclass=ABCMeta):
sql_query = pd.read_sql(
text(query),
con=self._database.connection,
con=self._database.session.connection(),
params=params)
df = pd.DataFrame(

View File

@ -5,7 +5,7 @@ from fastapi.exceptions import HTTPException
from fastapi.routing import APIRoute
from starlette import status
from src.db.database import Database
# from src.db.database import Database
from src.depends.auth import (check_session_expired, get_current_session,
verify_session)
from src.error.exceptions import DBException, UnexpectedException
@ -78,16 +78,16 @@ class MeDaCaRoute(APIRoute):
return response
class PrepareDatabaseRoute(MeDaCaRoute):
"""事前処理として、データベースのエンジンを作成するルートハンドラー
Args:
MeDaCaRoute (MeDaCaRoute): 共通ルートハンドラー
"""
async def pre_process_route(self, request: Request):
request = await super().pre_process_route(request)
# DBエンジンを構築して状態にセット
request.app.state._db = Database.get_instance()
return request
# class PrepareDatabaseRoute(MeDaCaRoute):
# """事前処理として、データベースのエンジンを作成するルートハンドラー
# Args:
# MeDaCaRoute (MeDaCaRoute): 共通ルートハンドラー
# """
# async def pre_process_route(self, request: Request):
# request = await super().pre_process_route(request)
# # DBエンジンを構築して状態にセット
# request.app.state._db = Database.get_instance()
# return request
class BeforeCheckSessionRoute(MeDaCaRoute):
@ -112,12 +112,8 @@ class BeforeCheckSessionRoute(MeDaCaRoute):
return session_request
class AfterSetCookieSessionRoute(PrepareDatabaseRoute):
"""事後処理として、セッションキーをcookieに設定するカスタムルートハンドラー
Args:
PrepareDatabaseRoute (PrepareDatabaseRoute): DBエンジンセットアップルートハンドラー
"""
class AfterSetCookieSessionRoute(MeDaCaRoute): # (PrepareDatabaseRoute):
"""事後処理として、セッションキーをcookieに設定するカスタムルートハンドラー"""
async def post_process_route(self, request: Request, response: Response):
response = await super().post_process_route(request, response)
session_key = response.headers.get('session_key', None)