Chore/python3.4 (#2)
authorShane Jaroch <chown_tee@proton.me>
Mon, 11 Jul 2022 15:46:14 +0000 (11:46 -0400)
committerGitHub <noreply@github.com>
Mon, 11 Jul 2022 15:46:14 +0000 (11:46 -0400)
ntclient/__main__.py
ntclient/persistence/sql/__init__.py
ntclient/persistence/sql/nt/__init__.py
ntclient/persistence/sql/nt/funcs.py
ntclient/persistence/sql/usda/__init__.py
ntclient/persistence/sql/usda/funcs.py

index e500e9c5f0b4b97c4804ef2c7181e8bf5ff798b1..8df1b2ef09170640896606f76c2a4430ebd02701 100644 (file)
@@ -27,7 +27,6 @@ along with this program.  If not, see <https://www.gnu.org/licenses/>.
 import argparse
 import sys
 import time
-from typing import Sequence
 from urllib.error import HTTPError, URLError
 
 import argcomplete
@@ -76,8 +75,12 @@ def build_argparser() -> argparse.ArgumentParser:
     return arg_parser
 
 
-def main(args: Sequence[str] = None) -> int:
-    """Main method for CLI"""
+def main(args: list = None) -> int:
+    """
+    Main method for CLI
+
+    @param args: List[str] | None
+    """
 
     start_time = time.time()
     arg_parser = build_argparser()
index 58d6f4e9be29c31fb68374c65d1779d987a23850..807142d54c20f6a8278e3b8173b995d986e77c99 100644 (file)
@@ -1,23 +1,26 @@
 """Main SQL persistence module, need to rethink circular imports and shared code"""
 import sqlite3
-from typing import Union
 
 # FIXME: maybe just use separate methods for calls with vs. without headers
 #  avoid the mypy headaches, and the liberal comments  # type: ignore
 
 
-def sql_entries(sql_result: sqlite3.Cursor, headers=False) -> Union[list, tuple]:
+def sql_entries(sql_result: sqlite3.Cursor) -> list:
     """Formats and returns a `sql_result()` for console digestion and output"""
     # TODO: return object: metadata, command, status, errors, etc?
     rows = sql_result.fetchall()
 
-    if headers:
-        headers = [x[0] for x in sql_result.description]
-        return headers, rows
-
     return rows
 
 
+def sql_entries_headers(sql_result: sqlite3.Cursor) -> tuple:
+    """Formats and returns a `sql_result()` for console digestion and output"""
+    rows = sql_result.fetchall()
+    headers = [x[0] for x in sql_result.description]
+
+    return headers, rows
+
+
 def version(con: sqlite3.Connection) -> str:
     """Gets the latest entry from version table"""
 
@@ -38,13 +41,11 @@ def close_con_and_cur(
     con.close()
 
 
-def _sql(
-    con: sqlite3.Connection,
-    query: str,
-    db_name: str,
-    values: Union[tuple, list] = None,
-    headers=False,
-) -> Union[list, tuple]:
+def _prep_query(
+    con: sqlite3.Connection, query: str, db_name: str, values=None
+) -> tuple:
+    """@param values: tuple | list | None"""
+
     from ntclient import DEBUG  # pylint: disable=import-outside-toplevel
 
     cur = con.cursor()
@@ -52,19 +53,51 @@ def _sql(
     if DEBUG:
         print("%s.sqlite3: %s" % (db_name, query))
         if values:
-            # TODO: better debug logging, more "control-findable", distinguish from most prints()
+            # TODO: better debug logging, more "control-findable",
+            #  distinguish from most prints()
             print(values)
 
     # TODO: separate `entry` & `entries` entity for single vs. bulk insert?
     if values:
         if isinstance(values, list):
-            rows = cur.executemany(query, values)
+            result = cur.executemany(query, values)
         else:  # tuple
-            rows = cur.execute(query, values)
+            result = cur.execute(query, values)
     else:
-        rows = cur.execute(query)
+        result = cur.execute(query)
+
+    return cur, result
+
+
+def _sql(
+    con: sqlite3.Connection,
+    query: str,
+    db_name: str,
+    values=None,
+) -> list:
+    """@param values: tuple | list | None"""
+
+    cur, result = _prep_query(con, query, db_name, values)
+
+    # TODO: print "<number> SELECTED", or other info
+    #  BASED ON command SELECT/INSERT/DELETE/UPDATE
+    result = sql_entries(cur)
+
+    close_con_and_cur(con, cur)
+    return result
+
+
+def _sql_headers(
+    con: sqlite3.Connection,
+    query: str,
+    db_name: str,
+    values=None,
+) -> tuple:
+    """@param values: tuple | list | None"""
+
+    cur, result = _prep_query(con, query, db_name, values)
+
+    result = sql_entries_headers(result)
 
-    # TODO: print "<number> SELECTED", or other info BASED ON command SELECT/INSERT/DELETE/UPDATE
-    result = sql_entries(rows, headers=headers)
     close_con_and_cur(con, cur)
     return result
index 8dee265749c27f8e9cd1523f6840fd6e4872bb7e..cc1e75043e9741aaec6ca5040c2ff5e41f944a79 100644 (file)
@@ -3,11 +3,11 @@ import os
 import sqlite3
 
 from ntclient import NT_DB_NAME, NUTRA_DIR, __db_target_nt__
-from ntclient.persistence.sql import _sql, version
+from ntclient.persistence.sql import _sql, _sql_headers, version
 from ntclient.utils.exceptions import SqlConnectError, SqlInvalidVersionError
 
 
-def nt_sqlite_connect(version_check=True):
+def nt_sqlite_connect(version_check=True) -> sqlite3.Connection:
     """Connects to the nt.sqlite3 file, or throws an exception"""
     db_path = os.path.join(NUTRA_DIR, NT_DB_NAME)
     if os.path.isfile(db_path):
@@ -29,13 +29,19 @@ def nt_sqlite_connect(version_check=True):
     raise SqlConnectError("ERROR: nt database doesn't exist, please run `nutra init`")
 
 
-def nt_ver():
+def nt_ver() -> str:
     """Gets version string for nt.sqlite3 database"""
     con = nt_sqlite_connect(version_check=False)
     return version(con)
 
 
-def sql(query, values=None, headers=False):
+def sql(query, values=None) -> list:
     """Executes a SQL command to nt.sqlite3"""
     con = nt_sqlite_connect()
-    return _sql(con, query, db_name="nt", values=values, headers=headers)
+    return _sql(con, query, db_name="nt", values=values)
+
+
+def sql_headers(query, values=None) -> tuple:
+    """Executes a SQL command to nt.sqlite3"""
+    con = nt_sqlite_connect()
+    return _sql_headers(con, query, db_name="nt", values=values)
index 7c042afe8d71e9877de14ba6ba5e0c07f8a43779..bee8fc6fdced07ad16d497c4467f0794dda8e014 100644 (file)
@@ -1,5 +1,5 @@
 """nt.sqlite3 functions module"""
-from ntclient.persistence.sql.nt import sql
+from ntclient.persistence.sql.nt import sql, sql_headers
 
 
 def sql_nt_next_index(table=None):
@@ -33,7 +33,7 @@ FROM
 GROUP BY
   id;
 """
-    return sql(query, headers=True)
+    return sql_headers(query)
 
 
 def sql_analyze_recipe(recipe_id):
index a2f27aae39506d979fc6b495bee551bd3f0bfeca..4c3df950b32df9915f09877f74a783046cc76025 100644 (file)
@@ -3,10 +3,9 @@ import os
 import sqlite3
 import tarfile
 import urllib.request
-from typing import Union
 
 from ntclient import NUTRA_DIR, USDA_DB_NAME, __db_target_usda__
-from ntclient.persistence.sql import _sql, version
+from ntclient.persistence.sql import _sql, _sql_headers, version
 from ntclient.utils.exceptions import SqlConnectError, SqlInvalidVersionError
 
 
@@ -92,10 +91,19 @@ def usda_ver() -> str:
     return version(con)
 
 
-def sql(query, values=None, headers=False, version_check=True) -> Union[list, tuple]:
+def sql(query, values=None, version_check=True) -> list:
     """Executes a SQL command to usda.sqlite3"""
 
     con = usda_sqlite_connect(version_check=version_check)
 
     # TODO: support argument: _sql(..., params=params, ...)
-    return _sql(con, query, db_name="usda", values=values, headers=headers)
+    return _sql(con, query, db_name="usda", values=values)
+
+
+def sql_headers(query, values=None, version_check=True) -> tuple:
+    """Executes a SQL command to usda.sqlite3 [WITH HEADERS]"""
+
+    con = usda_sqlite_connect(version_check=version_check)
+
+    # TODO: support argument: _sql(..., params=params, ...)
+    return _sql_headers(con, query, db_name="usda", values=values)
index ff21588863e4d31bc02b07d1c4b3613459289420..00bdd61436db6424d05c6b8d7eff9bd433fd5b8a 100644 (file)
@@ -1,5 +1,5 @@
 """usda.sqlite functions module"""
-from ntclient.persistence.sql.usda import sql
+from ntclient.persistence.sql.usda import sql, sql_headers
 from ntclient.utils import NUTR_ID_KCAL
 
 
@@ -25,7 +25,7 @@ def sql_food_details(food_ids=None) -> list:
         food_ids = ",".join(str(x) for x in set(food_ids))
         query = query % food_ids
 
-    return sql(query)  # type: ignore
+    return sql(query)
 
 
 def sql_nutrients_overview() -> dict:
@@ -40,7 +40,7 @@ def sql_nutrients_details() -> tuple:
     """Shows nutrients 'details'"""
 
     query = "SELECT * FROM nutrients_overview;"
-    return sql(query, headers=True)  # type: ignore
+    return sql_headers(query)
 
 
 def sql_servings(food_ids) -> list:
@@ -59,7 +59,7 @@ WHERE
   serv.food_id IN (%s);
 """
     food_ids = ",".join(str(x) for x in set(food_ids))
-    return sql(query % food_ids)  # type: ignore
+    return sql(query % food_ids)
 
 
 def sql_analyze_foods(food_ids) -> list:
@@ -77,7 +77,7 @@ WHERE
 """
     # TODO: parameterized queries
     food_ids = ",".join(str(x) for x in set(food_ids))
-    return sql(query % food_ids)  # type: ignore
+    return sql(query % food_ids)
 
 
 ################################################################################
@@ -100,7 +100,7 @@ ORDER BY
   food_id;
 """
 
-    return sql(query % (NUTR_ID_KCAL, nutrient_id))  # type: ignore
+    return sql(query % (NUTR_ID_KCAL, nutrient_id))
 
 
 def sql_sort_foods(nutr_id) -> list:
@@ -126,7 +126,7 @@ ORDER BY
   nut_data.nutr_val DESC;
 """
 
-    return sql(query % nutr_id)  # type: ignore
+    return sql(query % nutr_id)
 
 
 def sql_sort_foods_by_kcal(nutr_id) -> list:
@@ -155,4 +155,4 @@ ORDER BY
   (nut_data.nutr_val / kcal.nutr_val) DESC;
 """
 
-    return sql(query % nutr_id)  # type: ignore
+    return sql(query % nutr_id)