347 lines
9.4 KiB
Python
347 lines
9.4 KiB
Python
|
#!/usr/bin/env python
|
|||
|
# -*- coding: utf-8 -*-
|
|||
|
|
|||
|
"""
Database connection configuration module.

Contains the database connection parameters and SQLAlchemy setup, as well as
the Redis cache configuration.
"""
|
|||
|
|
|||
|
import json
import logging
import os
from urllib.parse import quote

from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, scoped_session
|
|||
|
|
|||
|
class ConfigDict:
    """Configuration container whose entries are accessible as attributes (dot notation).

    Nested ``dict`` values are converted recursively into ``ConfigDict`` instances,
    so ``cfg.default.host`` works for ``ConfigDict(default={'host': ...})``.
    """

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            if isinstance(value, dict):
                value = ConfigDict(**value)
            setattr(self, key, value)

    def get(self, key, default=None):
        """Return the configured entry ``key``, or ``default`` if it is not set.

        Only the configured entries are consulted, so names of methods such as
        ``'get'`` or ``'to_dict'`` are not mistaken for configuration values.
        """
        return self.__dict__.get(key, default)

    def to_dict(self):
        """Return the entries as a plain dict, converting nested ``ConfigDict`` values recursively."""
        return {
            key: value.to_dict() if isinstance(value, ConfigDict) else value
            for key, value in self.__dict__.items()
        }
|
|||
|
|
|||
|
# Database connection settings, keyed by environment name.
# The 'default' connection parameters can be overridden through environment
# variables so credentials need not live in source; the fallbacks are the
# previous hard-coded values.
DB_CONFIG = ConfigDict(
    default=dict(
        dialect='mysql',
        driver='pymysql',
        username=os.environ.get('TIANFENG_DB_USER', 'root'),
        password=os.environ.get('TIANFENG_DB_PASSWORD', 'root'),
        host=os.environ.get('TIANFENG_DB_HOST', 'localhost'),
        port=int(os.environ.get('TIANFENG_DB_PORT', '3306')),
        database=os.environ.get('TIANFENG_DB_NAME', 'tianfeng_task'),
        charset='utf8mb4',
    ),
    test=dict(
        dialect='sqlite',
        database=':memory:',
    ),
)
|
|||
|
|
|||
|
# Redis cache settings, keyed by environment name.
REDIS_CONFIG = ConfigDict(**{
    'default': {
        'host': 'localhost',
        'port': 6379,
        'db': 0,
        'password': None,
        'prefix': 'tianfeng:',
        'socket_timeout': 5,
        'socket_connect_timeout': 5,
        'decode_responses': True,
    },
    # The test environment uses its own logical database and key prefix.
    'test': {
        'host': 'localhost',
        'port': 6379,
        'db': 1,
        'password': None,
        'prefix': 'tianfeng_test:',
        'decode_responses': True,
    },
})
|
|||
|
|
|||
|
# Active environment name; selects the matching entry of DB_CONFIG / REDIS_CONFIG.
# Set via the TIANFENG_ENV environment variable (defaults to 'default').
ENV = os.environ.get('TIANFENG_ENV', 'default')

# Database settings for the active environment.
# Only configured environments are accepted (not ConfigDict attributes such as
# its methods), and a typo fails fast with a readable message.
if ENV not in vars(DB_CONFIG):
    raise ValueError(
        f"Unknown TIANFENG_ENV {ENV!r}; expected one of: "
        f"{', '.join(sorted(vars(DB_CONFIG)))}"
    )
db_conf = getattr(DB_CONFIG, ENV)

# Build the database connection URL.
if db_conf.dialect == 'sqlite':
    DATABASE_URL = f"{db_conf.dialect}:///{db_conf.database}"
else:
    # Percent-encode credentials so characters such as '@', ':' or '/' in a
    # password cannot corrupt the URL structure.
    DATABASE_URL = (
        f"{db_conf.dialect}+{db_conf.driver}://"
        f"{quote(str(db_conf.username), safe='')}:"
        f"{quote(str(db_conf.password), safe='')}@"
        f"{db_conf.host}:{db_conf.port}/{db_conf.database}?"
        f"charset={db_conf.charset}"
    )
|
|||
|
|
|||
|
# Create the database engine.
# pool_pre_ping tests connections before use; pool_recycle replaces pooled
# connections older than 3600 seconds.
_engine_kwargs = dict(
    pool_recycle=3600,
    pool_pre_ping=True,
    echo=False,  # set to True to log emitted SQL statements, for debugging
)
# Queue-pool sizing applies to server databases only. SQLite engines may use a
# different pool class (SingletonThreadPool for ':memory:'), which does not
# accept ``max_overflow`` and would make create_engine raise TypeError.
if not DATABASE_URL.startswith('sqlite'):
    _engine_kwargs.update(pool_size=20, max_overflow=0)

engine = create_engine(DATABASE_URL, **_engine_kwargs)
|
|||
|
|
|||
|
# Session factory bound to the module-level engine.
SessionFactory = sessionmaker(bind=engine)

# Scoped (per-thread by default) session registry built from the factory.
db_session = scoped_session(session_factory=SessionFactory)

# Declarative base for all models; ``Model.query`` is bound to db_session.
Base = declarative_base()
Base.query = db_session.query_property()
|
|||
|
|
|||
|
# Database configuration class
class DBConfig:
    """Database configuration holder exposing the engine, session and setup helpers."""

    config = DB_CONFIG    # settings for every environment
    env = ENV             # name of the active environment
    url = DATABASE_URL    # connection URL of the active environment
    engine = engine       # module-level engine (resolved from globals at class creation)
    session = db_session  # scoped_session registry
    base = Base           # declarative base the models derive from

    @classmethod
    def get_config(cls):
        """Return the database settings for the active environment."""
        return getattr(cls.config, cls.env)

    @classmethod
    def get_session(cls):
        """Return the scoped session registry ``db_session``."""
        return cls.session

    @classmethod
    def init_db(cls):
        """
        Initialize the database.

        For non-SQLite backends the configured database is created first if it
        does not exist; then all tables registered on ``Base`` are created.
        """
        # Import all models so that they are registered on Base's metadata.
        import data.models  # noqa: F401

        db_conf = cls.get_config()
        if db_conf.dialect != 'sqlite':
            cls._create_database_if_missing(db_conf)

        # Create all tables.
        cls.base.metadata.create_all(bind=cls.engine)

    @staticmethod
    def _create_database_if_missing(db_conf):
        """Issue ``CREATE DATABASE IF NOT EXISTS`` (MySQL syntax) for ``db_conf.database``."""
        from sqlalchemy import text

        # Connect to the server without selecting a database, since it may not
        # exist yet. Credentials are percent-encoded so special characters in a
        # password do not corrupt the URL.
        temp_url = (
            f"{db_conf.dialect}+{db_conf.driver}://"
            f"{quote(str(db_conf.username), safe='')}:"
            f"{quote(str(db_conf.password), safe='')}@"
            f"{db_conf.host}:{db_conf.port}/"
            f"?charset={db_conf.charset}"
        )
        temp_engine = create_engine(temp_url)
        try:
            with temp_engine.connect() as conn:
                # Quote the name with the dialect's own rules so names containing
                # '-' or reserved words still form valid SQL.
                db_name = temp_engine.dialect.identifier_preparer.quote(db_conf.database)
                conn.execute(text(
                    f"CREATE DATABASE IF NOT EXISTS {db_name} "
                    f"CHARACTER SET {db_conf.charset} COLLATE {db_conf.charset}_unicode_ci;"
                ))
                conn.commit()
        finally:
            # Release the temporary pool even if the statement failed.
            temp_engine.dispose()

    @classmethod
    def shutdown_session(cls, exception=None):
        """
        Remove the current scoped session.

        Intended to be called when the application shuts down; ``exception`` is
        accepted but unused.
        """
        cls.session.remove()
|
|||
|
|
|||
|
# Cache configuration class
class CacheConfig:
    """Redis cache configuration plus helpers for prefixed (namespaced) cache access."""

    config = REDIS_CONFIG  # settings for every environment
    env = ENV              # name of the active environment
    _redis_client = None   # lazily created client, stored on the class

    # Keys requested per SCAN round trip and deleted per DEL call in clear_all.
    _CLEAR_BATCH_SIZE = 500

    @classmethod
    def get_config(cls):
        """Return the Redis settings for the active environment."""
        return getattr(cls.config, cls.env)

    @classmethod
    def get_redis_client(cls):
        """
        Return the Redis client, creating it on first call.

        Returns:
            redis.Redis or None: the cached client, or None if creating it failed
            (the error is logged and creation is retried on the next call).

        Raises:
            ImportError: if the ``redis`` package is not installed.
        """
        if cls._redis_client is None:
            try:
                import redis
            except ImportError as err:
                raise ImportError(
                    "Redis package is not installed. Please install it with 'pip install redis'"
                ) from err
            try:
                redis_conf = cls.get_config()
                cls._redis_client = redis.Redis(
                    host=redis_conf.host,
                    port=redis_conf.port,
                    db=redis_conf.db,
                    password=redis_conf.password,
                    socket_timeout=getattr(redis_conf, 'socket_timeout', 5),
                    socket_connect_timeout=getattr(redis_conf, 'socket_connect_timeout', 5),
                    decode_responses=getattr(redis_conf, 'decode_responses', True),
                )
            except Exception as e:
                # Best effort: the cache helpers below treat a None client as
                # "cache unavailable" and return their fallback values.
                logging.getLogger(__name__).error("Error connecting to Redis: %s", e)
                return None
        return cls._redis_client

    @classmethod
    def _prefix(cls):
        """Return the key prefix of the active environment ('tianfeng:' if not configured)."""
        return getattr(cls.get_config(), 'prefix', 'tianfeng:')

    @classmethod
    def get_key(cls, key):
        """Return ``key`` with the active environment's cache prefix prepended."""
        return f"{cls._prefix()}{key}"

    @classmethod
    def set(cls, key, value, expire=None):
        """
        Store a value in the cache.

        Args:
            key (str): cache key (without prefix)
            value (any): value to store; str, bytes, int and float are stored
                as-is, everything else (including bool and None) is JSON-encoded
            expire (int, optional): time-to-live in seconds; falsy means no expiry

        Returns:
            bool: whether the value was stored (False if Redis is unavailable)
        """
        redis_client = cls.get_redis_client()
        if redis_client is None:
            return False

        # redis-py rejects bool (and None) values, so those are JSON-encoded too,
        # matching the documented "non-string types are JSON-serialized" rule.
        if isinstance(value, bool) or not isinstance(value, (str, bytes, int, float)):
            value = json.dumps(value)

        full_key = cls.get_key(key)
        if expire:
            return redis_client.setex(full_key, expire, value)
        return redis_client.set(full_key, value)

    @classmethod
    def get(cls, key, default=None):
        """
        Read a value from the cache.

        Values that look like a JSON object or array are decoded; any other
        value is returned exactly as Redis returned it.

        Args:
            key (str): cache key (without prefix)
            default (any, optional): returned when the key is missing or Redis
                is unavailable

        Returns:
            any: the cached value or ``default``
        """
        redis_client = cls.get_redis_client()
        if redis_client is None:
            return default

        value = redis_client.get(cls.get_key(key))
        if value is None:
            return default

        # The client may return bytes (decode_responses=False) or str.
        json_starts = (b'{', b'[') if isinstance(value, bytes) else ('{', '[')
        if value.startswith(json_starts):
            try:
                return json.loads(value)
            except ValueError:  # includes json.JSONDecodeError and UnicodeDecodeError
                pass
        return value

    @classmethod
    def delete(cls, key):
        """
        Delete a cache entry.

        Args:
            key (str): cache key (without prefix)

        Returns:
            bool: True if an entry was deleted
        """
        redis_client = cls.get_redis_client()
        if redis_client is None:
            return False
        return redis_client.delete(cls.get_key(key)) > 0

    @classmethod
    def exists(cls, key):
        """
        Check whether a cache entry exists.

        Args:
            key (str): cache key (without prefix)

        Returns:
            bool: True if the entry exists
        """
        redis_client = cls.get_redis_client()
        if redis_client is None:
            return False
        return redis_client.exists(cls.get_key(key)) > 0

    @classmethod
    def ttl(cls, key):
        """
        Return the remaining time-to-live of a cache entry.

        Args:
            key (str): cache key (without prefix)

        Returns:
            int: remaining seconds; -1 means no expiry, -2 means the key does not
            exist (also returned when Redis is unavailable)
        """
        redis_client = cls.get_redis_client()
        if redis_client is None:
            return -2
        return redis_client.ttl(cls.get_key(key))

    @classmethod
    def clear_all(cls):
        """
        Delete every cache entry under the active environment's prefix.

        Returns:
            bool: True if nothing matched or at least one key was deleted;
            False if Redis is unavailable or matched keys were already gone
        """
        redis_client = cls.get_redis_client()
        if redis_client is None:
            return False

        # SCAN walks the keyspace incrementally; KEYS would block the server
        # while it scans the whole database.
        batch_size = cls._CLEAR_BATCH_SIZE
        matched = False
        deleted = 0
        batch = []
        for found in redis_client.scan_iter(match=f"{cls._prefix()}*", count=batch_size):
            batch.append(found)
            if len(batch) >= batch_size:
                matched = True
                deleted += redis_client.delete(*batch)
                batch = []
        if batch:
            matched = True
            deleted += redis_client.delete(*batch)

        return not matched or deleted > 0
|
|||
|
|
|||
|
# Module-level wrappers kept for backward compatibility with older code.
def init_db():
    """Initialize the database; thin alias for ``DBConfig.init_db()``."""
    return DBConfig.init_db()
|
|||
|
|
|||
|
def shutdown_session(exception=None):
    """Remove the current session; thin alias for ``DBConfig.shutdown_session``."""
    return DBConfig.shutdown_session(exception=exception)
|