From da656d958253190be9bc01fa5b84ef2719d0b043 Mon Sep 17 00:00:00 2001 From: ksyasuda Date: Thu, 20 Oct 2022 00:05:58 -0700 Subject: [PATCH] update package --- pydb/.gitignore | 1 + pydb/pydb/pydb.py | 41 ++++++++++++++++++++++------------------ pydb/tests/test_mysql.py | 2 +- 3 files changed, 25 insertions(+), 19 deletions(-) diff --git a/pydb/.gitignore b/pydb/.gitignore index 7b733dd..ebc86a6 100644 --- a/pydb/.gitignore +++ b/pydb/.gitignore @@ -2,6 +2,7 @@ __pycache__/ *.py[cod] *$py.class +config.py # C extensions *.so diff --git a/pydb/pydb/pydb.py b/pydb/pydb/pydb.py index 0762162..02440ff 100644 --- a/pydb/pydb/pydb.py +++ b/pydb/pydb/pydb.py @@ -22,10 +22,13 @@ logger.addHandler(handler) def convert_query_result(query_res): """Return List(elements) from query result List(Tuples(element)).""" - return [i[0] if len(query_res) == 1 else list(i) for i in query_res] + if len(query_res) == 0: + logger.warning("Received empty result set from DB") + return [] + return [list(i) for i in query_res] if len(query_res) > 1 else query_res[0][0] -class DB_WRAPPER: +class DbWrapper: """Base Class for DB Connection Wrapper""" def __init__(self, info: DB_INFO, connector, exception=None, dictionary=False): @@ -44,10 +47,9 @@ class DB_WRAPPER: else: self._dictionary = False self._cur = self._conn.cursor() - self._Exception = exception if exception is not None else Exception - except Exception as exception: - logger.critical("Something went wrong connection to DB:", exception) - raise exception + self._exception = exception if exception is not None else Exception + except self._exception as e: + raise e def get_connection(self): """Returns the connection object for the DB""" @@ -67,7 +69,7 @@ class DB_WRAPPER: Returns: [type]: [description] """ - return self._Exception + return self._exception def query(self, stmt): """Queries the db with . Returns list of tuples if there are results.""" @@ -75,7 +77,7 @@ class DB_WRAPPER: self._cur.execute(stmt) res = self._cur.fetchall() return convert_query_result(res) if not self._dictionary else res - except self._Exception as exception: + except self._exception as exception: raise exception def query_with_commit(self, stmt): @@ -83,7 +85,7 @@ class DB_WRAPPER: try: self._cur.execute(stmt) self._cur.commit() - except self._Exception as exception: + except self._exception as exception: raise exception def execute(self, stmt): @@ -96,7 +98,7 @@ class DB_WRAPPER: if not self._dictionary else self._cur.fetchall() ) - except self._Exception as e: + except self._exception as e: raise e def close(self): @@ -105,12 +107,12 @@ class DB_WRAPPER: self._conn.close() -class MysqlDB(DB_WRAPPER): +class MysqlDB(DbWrapper): """Mysql Specific Functions""" def __init__(self, info, dictionary=False): """MySQL Connection Wrapper""" - DB_WRAPPER.__init__( + DbWrapper.__init__( self, info, mysql.connector.connect, @@ -152,7 +154,8 @@ class MysqlDB(DB_WRAPPER): def table_exists(self, schema: str, table: str): stmt = f""" SELECT COUNT(*) from information_schema.TABLES - WHERE TABLE_SCHEMA = '{schema}' and TABLE_NAME = '{table}' + WHERE UPPER(TABLE_SCHEMA) = UPPER('{schema}') + AND UPPER(TABLE_NAME) = UPPER('{table}') """ return self.query(stmt)[0][0] != 0 @@ -196,9 +199,15 @@ class DatabaseManager: def db_factory( db_info: Union[Dict[str, int], Dict[str, str]], db_type: str, dictionary=False ): + """Build the correct db wrapper for the db_type.""" db_type = db_type.strip().lower() if db_type == "mysql": - return MysqlDB(db_info, dictionary=dictionary) + dbo = MysqlDB(db_info, dictionary=dictionary) + else: + logger.error("ERROR %s not valid", db_type) + logger.error("Valid types: [ mysql ]") + sys.exit(1) + return dbo # elif db_type == "snowflake": # return SnowflakeWrapper( # db_info, snowflake.connector.connect, snowflake.connector.Error @@ -211,7 +220,3 @@ def db_factory( # ) # elif db_type == 'mariadb': # return _MariaDB(db_info) - else: - logger.error("ERROR %s not valid", db_type) - logger.error("Valid types: [ mysql ]") - sys.exit(1) diff --git a/pydb/tests/test_mysql.py b/pydb/tests/test_mysql.py index 26a0ea5..f5fa512 100644 --- a/pydb/tests/test_mysql.py +++ b/pydb/tests/test_mysql.py @@ -28,4 +28,4 @@ def test_mysql_2(): db = DatabaseManager(config.MYSQL_INFO, "mysql").__enter__() res = db.query("SELECT COUNT(*) FROM dir_map") logger.info("Result: %s", res) - assert res is not None and res[0] > 0 + assert res is not None and res > 0