108 lines
3.0 KiB
Python
108 lines
3.0 KiB
Python
import json
|
|
from datetime import datetime
|
|
from peewee import fn
|
|
from .db import db, Task, Category, TaskStatus, SavedSearch
|
|
|
|
|
|
def category_lookup(category):
|
|
category_id = None
|
|
if category:
|
|
category, _ = Category.get_or_create(name=category)
|
|
return category.id
|
|
|
|
|
|
def add_task(
    text: str,
    category: str,
    status: str = TaskStatus.ZERO.value,
    due: datetime | None = None,
    type: str = "",
) -> Task:
    """
    Create a task and persist it.

    The category name is resolved (and created if missing) through
    ``category_lookup`` inside the same ``db.atomic()`` block as the insert,
    so the category write and the task insert share one transaction.

    Returns the newly created ``Task``.
    """
    with db.atomic():
        new_task = Task.create(
            text=text,
            type=type,
            status=status,
            due=due,
            category_id=category_lookup(category),
        )
    return new_task
|
|
|
|
|
|
def update_task(
    task_id: int,
    **kwargs,
) -> Task:
    """
    Update fields of an existing task and return the refreshed instance.

    ``kwargs`` maps Task field names to new values. A truthy ``category``
    entry is treated as a category *name* and translated into
    ``category_id`` (creating the category if it does not exist yet). A
    falsy ``category`` is dropped, which leaves the current category as is.

    Raises ``Task.DoesNotExist`` (from ``get_by_id``) if no task has
    ``task_id``.
    """
    with db.atomic():
        if category := kwargs.pop("category", None):
            kwargs["category_id"] = category_lookup(category)

        # Existence check: get_by_id raises if the id is unknown, which also
        # rolls back any category created just above.
        task = Task.get_by_id(task_id)

        # Nothing left to write: avoid issuing an UPDATE with an empty SET
        # clause and hand back the row unchanged.
        if not kwargs:
            return task

        Task.update(kwargs).where(Task.id == task_id).execute()
        # Re-read so the caller sees the values as stored in the database.
        return Task.get_by_id(task_id)
|
|
|
|
|
|
def _parse_sort_string(sort_string, model_class):
|
|
"""
|
|
Convert sort string like 'field1,-field2' to peewee order_by expressions.
|
|
"""
|
|
sort_expressions = []
|
|
|
|
if not sort_string:
|
|
return sort_expressions
|
|
|
|
for field in sort_string.split(","):
|
|
is_desc = field.startswith("-")
|
|
field_name = field[1:] if is_desc else field
|
|
|
|
# special handling for due_date with COALESCE
|
|
if field_name == "due_date":
|
|
expr = fn.COALESCE(getattr(model_class, field_name), datetime(3000, 12, 31))
|
|
sort_expressions.append(expr.desc() if is_desc else expr)
|
|
else:
|
|
field_expr = getattr(model_class, field_name)
|
|
sort_expressions.append(field_expr.desc() if is_desc else field_expr)
|
|
|
|
return sort_expressions
|
|
|
|
|
|
def get_tasks(
    search_text: str | None = None,
    category: str | None = None,
    statuses: tuple[str, ...] | None = None,
    sort: str = "",
) -> list[Task]:
    """
    Return non-deleted tasks matching the given filters.

    Args:
        search_text: substring to look for in the task text, compared
            case-insensitively.
        category: category *name* to filter by. ``Category.get`` raises
            ``Category.DoesNotExist`` if no category has that name.
        statuses: only include tasks whose status is one of these values.
        sort: sort string understood by ``_parse_sort_string``
            (e.g. ``"status,-text"``).

    Empty / None filters are not applied.
    """
    query = Task.select().where(~Task.deleted)

    if search_text:
        query = query.where(fn.Lower(Task.text).contains(search_text.lower()))
    if category:
        query = query.where(Task.category == Category.get(name=category))
    if statuses:
        query = query.where(Task.status.in_(statuses))

    return list(query.order_by(*_parse_sort_string(sort, Task)))
|
|
|
|
|
|
def get_categories() -> list[Category]:
    """Return every category, ordered by name."""
    ordered = Category.select().order_by(Category.name)
    return [cat for cat in ordered]
|
|
|
|
|
|
def save_view(name: str, *, filters: dict, sort_string: str) -> SavedSearch:
    """
    Persist a named view (filters plus sort order) and return the new row.

    ``filters`` is stored as a JSON string, so it must be JSON-serialisable.
    """
    return SavedSearch.create(
        name=name,
        filters=json.dumps(filters),
        sort_string=sort_string,
    )
|
|
|
|
|
|
def get_saved_view_names() -> list[str]:
    """Return the names of all saved views, in the order the query yields them."""
    saved_searches = SavedSearch.select()
    return [saved.name for saved in saved_searches]
|
|
|
|
|
|
def get_saved_view(name: str) -> SavedSearch:
    """
    Fetch the saved view called ``name``.

    Raises ``SavedSearch.DoesNotExist`` (via peewee's ``Model.get``) when
    there is no such view.
    """
    matches_name = SavedSearch.name == name
    return SavedSearch.get(matches_name)
|