fix: 同時接続でエラーになる問題の解消策として、DBセッションの取り回し方法を修正した(まずは共通部分の未修正)
This commit is contained in:
parent
bed91795f1
commit
266e7cd28b
0
ecs/jskult-webapp/src/core/__init__.py
Normal file
0
ecs/jskult-webapp/src/core/__init__.py
Normal file
12
ecs/jskult-webapp/src/core/task.py
Normal file
12
ecs/jskult-webapp/src/core/task.py
Normal file
@ -0,0 +1,12 @@
|
||||
"""FastAPIサーバーの起動イベントのラッパー"""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
from src.db.tasks import init_db
|
||||
|
||||
|
||||
def create_start_app_handler() -> Callable:
|
||||
def start_app() -> None:
|
||||
init_db()
|
||||
|
||||
return start_app
|
||||
@ -1,6 +1,7 @@
|
||||
from sqlalchemy import (Connection, CursorResult, Engine, NullPool,
|
||||
create_engine, event, exc, text)
|
||||
from sqlalchemy import (CursorResult, Engine, NullPool, create_engine, event,
|
||||
exc, text)
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import Pool
|
||||
|
||||
from src.error.exceptions import DBException
|
||||
@ -24,9 +25,8 @@ def ping_connection(dbapi_connection, connection_record, connection_proxy):
|
||||
cursor.close()
|
||||
|
||||
|
||||
class Database:
|
||||
"""データベース操作クラス"""
|
||||
__connection: Connection = None
|
||||
class DatabaseSession:
|
||||
"""データベースセッション管理クラス"""
|
||||
__engine: Engine = None
|
||||
__host: str = None
|
||||
__port: str = None
|
||||
@ -34,6 +34,7 @@ class Database:
|
||||
__password: str = None
|
||||
__schema: str = None
|
||||
__connection_string: str = None
|
||||
__instance = None
|
||||
|
||||
def __init__(self, username: str, password: str, host: str, port: int, schema: str) -> None:
|
||||
"""このクラスの新たなインスタンスを初期化します
|
||||
@ -73,32 +74,49 @@ class Database:
|
||||
Returns:
|
||||
Database: DB操作クラスインスタンス
|
||||
"""
|
||||
return cls(
|
||||
username=environment.DB_USERNAME,
|
||||
password=environment.DB_PASSWORD,
|
||||
host=environment.DB_HOST,
|
||||
port=environment.DB_PORT,
|
||||
schema=environment.DB_SCHEMA
|
||||
)
|
||||
if cls.__instance is None:
|
||||
cls.__instance = cls(
|
||||
username=environment.DB_USERNAME,
|
||||
password=environment.DB_PASSWORD,
|
||||
host=environment.DB_HOST,
|
||||
port=environment.DB_PORT,
|
||||
schema=environment.DB_SCHEMA
|
||||
)
|
||||
|
||||
@property
|
||||
def connection(self):
|
||||
return cls.__instance
|
||||
|
||||
def create_session(self):
|
||||
"""
|
||||
DBの接続を返します。
|
||||
"""
|
||||
return self.__connection
|
||||
|
||||
def connect(self):
|
||||
"""
|
||||
DBに接続します。接続に失敗した場合、リトライします。
|
||||
Raises:
|
||||
DBException: 接続失敗
|
||||
"""
|
||||
try:
|
||||
self.__connection = self.__engine.connect()
|
||||
return sessionmaker(autoflush=False, bind=self.__engine)()
|
||||
except Exception as e:
|
||||
raise DBException(e)
|
||||
|
||||
|
||||
class DatabaseClient:
|
||||
|
||||
__session: Session = None
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
self.__session = session
|
||||
|
||||
@property
|
||||
def session(self) -> Session:
|
||||
"""
|
||||
DBのセッションを返します。
|
||||
"""
|
||||
if self.__session is None:
|
||||
raise DBException('DBに接続していません')
|
||||
return self.__session
|
||||
|
||||
def connect(self):
|
||||
...
|
||||
|
||||
def disconnect(self):
|
||||
...
|
||||
|
||||
def execute_select(self, select_query: str, parameters=None) -> list[dict]:
|
||||
"""SELECTクエリを実行します。
|
||||
|
||||
@ -112,14 +130,11 @@ class Database:
|
||||
Returns:
|
||||
list[dict]: カラム名: 値の辞書リスト
|
||||
"""
|
||||
if self.__connection is None:
|
||||
raise DBException('DBに接続していません')
|
||||
|
||||
result = None
|
||||
try:
|
||||
# トランザクションが開始している場合は、トランザクションを引き継ぐ
|
||||
if self.__connection.in_transaction():
|
||||
result = self.__connection.execute(text(select_query), parameters)
|
||||
if self.session.in_transaction():
|
||||
result = self.session.execute(text(select_query), parameters)
|
||||
else:
|
||||
# トランザクションが明示的に開始していない場合は、クエリ単位でトランザクションをbegin-commitする。
|
||||
result = self.__execute_with_transaction(select_query, parameters)
|
||||
@ -142,14 +157,11 @@ class Database:
|
||||
Returns:
|
||||
CursorResult: 取得結果
|
||||
"""
|
||||
if self.__connection is None:
|
||||
raise DBException('DBに接続していません')
|
||||
|
||||
result = None
|
||||
try:
|
||||
# トランザクションが開始している場合は、トランザクションを引き継ぐ
|
||||
if self.__connection.in_transaction():
|
||||
result = self.__connection.execute(text(query), parameters)
|
||||
if self.session.in_transaction():
|
||||
result = self.session.execute(text(query), parameters)
|
||||
else:
|
||||
# トランザクションが明示的に開始していない場合は、クエリ単位でトランザクションをbegin-commitする。
|
||||
result = self.__execute_with_transaction(query, parameters)
|
||||
@ -160,35 +172,35 @@ class Database:
|
||||
|
||||
def begin(self):
|
||||
"""トランザクションを開始します。"""
|
||||
if not self.__connection.in_transaction():
|
||||
self.__connection.begin()
|
||||
if not self.session.in_transaction():
|
||||
self.session.begin()
|
||||
|
||||
def commit(self):
|
||||
"""トランザクションをコミットします"""
|
||||
if self.__connection.in_transaction():
|
||||
self.__connection.commit()
|
||||
if self.session.in_transaction():
|
||||
self.session.commit()
|
||||
|
||||
def rollback(self):
|
||||
"""トランザクションをロールバックします"""
|
||||
if self.__connection.in_transaction():
|
||||
self.__connection.rollback()
|
||||
if self.session.in_transaction():
|
||||
self.session.rollback()
|
||||
|
||||
def disconnect(self):
|
||||
"""DB接続を切断します。"""
|
||||
if self.__connection is not None:
|
||||
self.__connection.close()
|
||||
self.__connection = None
|
||||
# def disconnect(self):
|
||||
# """DB接続を切断します。"""
|
||||
# if self.__session is not None:
|
||||
# self.__session.close()
|
||||
# self.session = None
|
||||
|
||||
def to_jst(self):
|
||||
self.execute('SET time_zone = "+9:00"')
|
||||
|
||||
def __execute_with_transaction(self, query: str, parameters: dict):
|
||||
# トランザクションを開始してクエリを実行する
|
||||
with self.__connection.begin():
|
||||
with self.session.begin():
|
||||
try:
|
||||
result = self.__connection.execute(text(query), parameters=parameters)
|
||||
result = self.session.execute(text(query), parameters)
|
||||
except Exception as e:
|
||||
self.__connection.rollback()
|
||||
self.session.rollback()
|
||||
raise e
|
||||
# ここでコミットされる
|
||||
return result
|
||||
|
||||
7
ecs/jskult-webapp/src/db/tasks.py
Normal file
7
ecs/jskult-webapp/src/db/tasks.py
Normal file
@ -0,0 +1,7 @@
|
||||
from src.db.database import DatabaseSession
|
||||
|
||||
|
||||
def init_db() -> None:
|
||||
# DB接続モジュールを初期化
|
||||
# 以降、get_instance()で唯一のインスタンスを取得する
|
||||
DatabaseSession.get_instance()
|
||||
@ -1,13 +1,15 @@
|
||||
from starlette.requests import Request
|
||||
|
||||
from src.db.database import Database
|
||||
from src.db.database import DatabaseClient, DatabaseSession
|
||||
|
||||
|
||||
def get_database(request: Request) -> Database:
|
||||
# medaca_routerでDB接続エンジンが初期化される
|
||||
db = getattr(request.app.state, '_db', None)
|
||||
# uvicornのワーカーが起動したタイミングでは、dbがセットされていないので、ここでセットここでセットする
|
||||
if db is None:
|
||||
db = Database.get_instance()
|
||||
setattr(request.app.state, '_db', db)
|
||||
return db
|
||||
def get_database(request: Request) -> DatabaseClient:
|
||||
try:
|
||||
database_session = DatabaseSession.get_instance()
|
||||
session = database_session.create_session()
|
||||
database = DatabaseClient(session)
|
||||
yield database
|
||||
finally:
|
||||
database.disconnect()
|
||||
# FIXME; リポジトリを直し終えたら消す
|
||||
session.close()
|
||||
|
||||
@ -2,13 +2,13 @@ from typing import Callable, Type
|
||||
|
||||
from fastapi import Depends
|
||||
|
||||
from src.db.database import Database
|
||||
from src.db.database import DatabaseClient
|
||||
from src.depends.database import get_database
|
||||
from src.services.base_service import BaseService
|
||||
|
||||
|
||||
def get_service(Service_type: Type[BaseService]) -> Callable:
|
||||
def get_service(db: Database = Depends(get_database)) -> Type[BaseService]:
|
||||
def get_service(db: DatabaseClient = Depends(get_database)) -> Type[BaseService]:
|
||||
repositories = {key: repository(db) for key, repository in Service_type.REPOSITORIES.items()}
|
||||
clients = {key: client() for key, client in Service_type.CLIENTS.items()}
|
||||
return Service_type(repositories=repositories, clients=clients)
|
||||
|
||||
@ -27,7 +27,7 @@ def get_logger(log_name: str) -> logging.Logger:
|
||||
logger.addHandler(handler)
|
||||
|
||||
formatter = logging.Formatter(
|
||||
'%(name)s\t[%(levelname)s]\t%(asctime)s\t%(message)s',
|
||||
'%(name)s[%(process)d][%(thread)d]\t[%(levelname)s]\t%(asctime)s\t%(message)s',
|
||||
'%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
|
||||
|
||||
@ -1,14 +1,18 @@
|
||||
import os.path as path
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from starlette import status
|
||||
|
||||
import src.static as static
|
||||
from src.controller import (bio, bio_download, healthcheck, login, logout,
|
||||
master_mainte, menu, root, ultmarc)
|
||||
from src.core import task
|
||||
from src.depends.services import get_service
|
||||
from src.error.exception_handler import http_exception_handler
|
||||
from src.error.exceptions import UnexpectedException
|
||||
from src.logging.get_logger import get_logger
|
||||
from src.services.batch_status_service import BatchStatusService
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@ -40,3 +44,21 @@ app.add_exception_handler(status.HTTP_403_FORBIDDEN, http_exception_handler)
|
||||
|
||||
# サーバーエラーが発生した場合のハンドラー。HTTPExceptionではハンドリングできないため、個別に設定
|
||||
app.add_exception_handler(UnexpectedException, http_exception_handler)
|
||||
|
||||
# サーバー起動時のイベント
|
||||
app.add_event_handler('startup', task.create_start_app_handler)
|
||||
|
||||
|
||||
# logger = get_logger(__name__)
|
||||
|
||||
|
||||
# @app.get('/sample/')
|
||||
# def sample(service: BatchStatusService = Depends(get_service(BatchStatusService))):
|
||||
# # import os
|
||||
# # import threading
|
||||
# logger.info('START')
|
||||
# res = service.hdke_table_record
|
||||
# logger.info(res)
|
||||
# logger.info('END')
|
||||
# return res
|
||||
# # return f'{os.getpid()}, {threading.get_ident()}'
|
||||
|
||||
@ -3,15 +3,15 @@ from abc import ABCMeta
|
||||
import pandas as pd
|
||||
from sqlalchemy import text
|
||||
|
||||
from src.db.database import Database
|
||||
from src.db.database import DatabaseClient
|
||||
from src.model.db.base_db_model import BaseDBModel
|
||||
|
||||
|
||||
class BaseRepository(metaclass=ABCMeta):
|
||||
|
||||
_database: Database
|
||||
_database: DatabaseClient
|
||||
|
||||
def __init__(self, db: Database) -> None:
|
||||
def __init__(self, db: DatabaseClient) -> None:
|
||||
self._database = db
|
||||
|
||||
def fetch_all(self) -> list[BaseDBModel]:
|
||||
@ -36,7 +36,7 @@ class BaseRepository(metaclass=ABCMeta):
|
||||
|
||||
sql_query = pd.read_sql(
|
||||
text(query),
|
||||
con=self._database.connection,
|
||||
con=self._database.session.connection(),
|
||||
params=params)
|
||||
|
||||
df = pd.DataFrame(
|
||||
|
||||
@ -5,7 +5,7 @@ from fastapi.exceptions import HTTPException
|
||||
from fastapi.routing import APIRoute
|
||||
from starlette import status
|
||||
|
||||
from src.db.database import Database
|
||||
# from src.db.database import Database
|
||||
from src.depends.auth import (check_session_expired, get_current_session,
|
||||
verify_session)
|
||||
from src.error.exceptions import DBException, UnexpectedException
|
||||
@ -78,16 +78,16 @@ class MeDaCaRoute(APIRoute):
|
||||
return response
|
||||
|
||||
|
||||
class PrepareDatabaseRoute(MeDaCaRoute):
|
||||
"""事前処理として、データベースのエンジンを作成するルートハンドラー
|
||||
Args:
|
||||
MeDaCaRoute (MeDaCaRoute): 共通ルートハンドラー
|
||||
"""
|
||||
async def pre_process_route(self, request: Request):
|
||||
request = await super().pre_process_route(request)
|
||||
# DBエンジンを構築して状態にセット
|
||||
request.app.state._db = Database.get_instance()
|
||||
return request
|
||||
# class PrepareDatabaseRoute(MeDaCaRoute):
|
||||
# """事前処理として、データベースのエンジンを作成するルートハンドラー
|
||||
# Args:
|
||||
# MeDaCaRoute (MeDaCaRoute): 共通ルートハンドラー
|
||||
# """
|
||||
# async def pre_process_route(self, request: Request):
|
||||
# request = await super().pre_process_route(request)
|
||||
# # DBエンジンを構築して状態にセット
|
||||
# request.app.state._db = Database.get_instance()
|
||||
# return request
|
||||
|
||||
|
||||
class BeforeCheckSessionRoute(MeDaCaRoute):
|
||||
@ -112,12 +112,8 @@ class BeforeCheckSessionRoute(MeDaCaRoute):
|
||||
return session_request
|
||||
|
||||
|
||||
class AfterSetCookieSessionRoute(PrepareDatabaseRoute):
|
||||
"""事後処理として、セッションキーをcookieに設定するカスタムルートハンドラー
|
||||
|
||||
Args:
|
||||
PrepareDatabaseRoute (PrepareDatabaseRoute): DBエンジンセットアップルートハンドラー
|
||||
"""
|
||||
class AfterSetCookieSessionRoute(MeDaCaRoute): # (PrepareDatabaseRoute):
|
||||
"""事後処理として、セッションキーをcookieに設定するカスタムルートハンドラー"""
|
||||
async def post_process_route(self, request: Request, response: Response):
|
||||
response = await super().post_process_route(request, response)
|
||||
session_key = response.headers.get('session_key', None)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user