update package
This commit is contained in:
parent
1869ed4057
commit
da656d9582
1
pydb/.gitignore
vendored
1
pydb/.gitignore
vendored
@ -2,6 +2,7 @@
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
config.py
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
@ -22,10 +22,13 @@ logger.addHandler(handler)
|
||||
|
||||
def convert_query_result(query_res):
|
||||
"""Return List(elements) from query result List(Tuples(element))."""
|
||||
return [i[0] if len(query_res) == 1 else list(i) for i in query_res]
|
||||
if len(query_res) == 0:
|
||||
logger.warning("Received empty result set from DB")
|
||||
return []
|
||||
return [list(i) for i in query_res] if len(query_res) > 1 else query_res[0][0]
|
||||
|
||||
|
||||
class DB_WRAPPER:
|
||||
class DbWrapper:
|
||||
"""Base Class for DB Connection Wrapper"""
|
||||
|
||||
def __init__(self, info: DB_INFO, connector, exception=None, dictionary=False):
|
||||
@ -44,10 +47,9 @@ class DB_WRAPPER:
|
||||
else:
|
||||
self._dictionary = False
|
||||
self._cur = self._conn.cursor()
|
||||
self._Exception = exception if exception is not None else Exception
|
||||
except Exception as exception:
|
||||
logger.critical("Something went wrong connection to DB:", exception)
|
||||
raise exception
|
||||
self._exception = exception if exception is not None else Exception
|
||||
except self._exception as e:
|
||||
raise e
|
||||
|
||||
def get_connection(self):
|
||||
"""Returns the connection object for the DB"""
|
||||
@ -67,7 +69,7 @@ class DB_WRAPPER:
|
||||
Returns:
|
||||
[type]: [description]
|
||||
"""
|
||||
return self._Exception
|
||||
return self._exception
|
||||
|
||||
def query(self, stmt):
|
||||
"""Queries the db with <stmt>. Returns list of tuples if there are results."""
|
||||
@ -75,7 +77,7 @@ class DB_WRAPPER:
|
||||
self._cur.execute(stmt)
|
||||
res = self._cur.fetchall()
|
||||
return convert_query_result(res) if not self._dictionary else res
|
||||
except self._Exception as exception:
|
||||
except self._exception as exception:
|
||||
raise exception
|
||||
|
||||
def query_with_commit(self, stmt):
|
||||
@ -83,7 +85,7 @@ class DB_WRAPPER:
|
||||
try:
|
||||
self._cur.execute(stmt)
|
||||
self._cur.commit()
|
||||
except self._Exception as exception:
|
||||
except self._exception as exception:
|
||||
raise exception
|
||||
|
||||
def execute(self, stmt):
|
||||
@ -96,7 +98,7 @@ class DB_WRAPPER:
|
||||
if not self._dictionary
|
||||
else self._cur.fetchall()
|
||||
)
|
||||
except self._Exception as e:
|
||||
except self._exception as e:
|
||||
raise e
|
||||
|
||||
def close(self):
|
||||
@ -105,12 +107,12 @@ class DB_WRAPPER:
|
||||
self._conn.close()
|
||||
|
||||
|
||||
class MysqlDB(DB_WRAPPER):
|
||||
class MysqlDB(DbWrapper):
|
||||
"""Mysql Specific Functions"""
|
||||
|
||||
def __init__(self, info, dictionary=False):
|
||||
"""MySQL Connection Wrapper"""
|
||||
DB_WRAPPER.__init__(
|
||||
DbWrapper.__init__(
|
||||
self,
|
||||
info,
|
||||
mysql.connector.connect,
|
||||
@ -152,7 +154,8 @@ class MysqlDB(DB_WRAPPER):
|
||||
def table_exists(self, schema: str, table: str):
|
||||
stmt = f"""
|
||||
SELECT COUNT(*) from information_schema.TABLES
|
||||
WHERE TABLE_SCHEMA = '{schema}' and TABLE_NAME = '{table}'
|
||||
WHERE UPPER(TABLE_SCHEMA) = UPPER('{schema}')
|
||||
AND UPPER(TABLE_NAME) = UPPER('{table}')
|
||||
"""
|
||||
return self.query(stmt)[0][0] != 0
|
||||
|
||||
@ -196,9 +199,15 @@ class DatabaseManager:
|
||||
def db_factory(
|
||||
db_info: Union[Dict[str, int], Dict[str, str]], db_type: str, dictionary=False
|
||||
):
|
||||
"""Build the correct db wrapper for the db_type."""
|
||||
db_type = db_type.strip().lower()
|
||||
if db_type == "mysql":
|
||||
return MysqlDB(db_info, dictionary=dictionary)
|
||||
dbo = MysqlDB(db_info, dictionary=dictionary)
|
||||
else:
|
||||
logger.error("ERROR %s not valid", db_type)
|
||||
logger.error("Valid types: [ mysql ]")
|
||||
sys.exit(1)
|
||||
return dbo
|
||||
# elif db_type == "snowflake":
|
||||
# return SnowflakeWrapper(
|
||||
# db_info, snowflake.connector.connect, snowflake.connector.Error
|
||||
@ -211,7 +220,3 @@ def db_factory(
|
||||
# )
|
||||
# elif db_type == 'mariadb':
|
||||
# return _MariaDB(db_info)
|
||||
else:
|
||||
logger.error("ERROR %s not valid", db_type)
|
||||
logger.error("Valid types: [ mysql ]")
|
||||
sys.exit(1)
|
||||
|
@ -28,4 +28,4 @@ def test_mysql_2():
|
||||
db = DatabaseManager(config.MYSQL_INFO, "mysql").__enter__()
|
||||
res = db.query("SELECT COUNT(*) FROM dir_map")
|
||||
logger.info("Result: %s", res)
|
||||
assert res is not None and res[0] > 0
|
||||
assert res is not None and res > 0
|
||||
|
Loading…
Reference in New Issue
Block a user