tt/src/tt/controller.py

108 lines
3.0 KiB
Python

import json
from datetime import datetime
from peewee import fn
from .db import db, Task, Category, TaskStatus, SavedSearch
def category_lookup(category):
category_id = None
if category:
category, _ = Category.get_or_create(name=category)
return category.id
def add_task(
    text: str,
    category: str,
    status: str = TaskStatus.ZERO.value,
    due: datetime | None = None,
    type: str = "",
) -> Task:
    """
    Create and persist a new task, returning the created ``Task`` row.

    The category name is resolved to an id via ``category_lookup`` (which may
    create the category). That lookup and the insert run inside the same
    ``db.atomic()`` block.
    """
    with db.atomic():
        row_values = {
            "text": text,
            "type": type,
            "status": status,
            "due": due,
            "category_id": category_lookup(category),
        }
        return Task.create(**row_values)
def update_task(
    task_id: int,
    **kwargs,
) -> Task:
    """
    Apply the field values in ``kwargs`` to task ``task_id`` and return the
    freshly re-read row.

    A ``category`` keyword is treated as a category *name*. If truthy, it is
    resolved (creating the category if needed) into ``category_id``. If falsy,
    it is discarded. Raises ``Task.DoesNotExist`` (from peewee's ``get_by_id``)
    when no task has that id.
    """
    with db.atomic():
        if category := kwargs.pop("category", None):
            kwargs["category_id"] = category_lookup(category)
        if not kwargs:
            # Nothing left to change; avoid issuing an UPDATE with an empty
            # SET clause and just return the current row.
            return Task.get_by_id(task_id)
        Task.update(kwargs).where(Task.id == task_id).execute()
        # Re-read so the returned instance reflects the persisted values.
        return Task.get_by_id(task_id)
def _parse_sort_string(sort_string, model_class):
"""
Convert sort string like 'field1,-field2' to peewee order_by expressions.
"""
sort_expressions = []
if not sort_string:
return sort_expressions
for field in sort_string.split(","):
is_desc = field.startswith("-")
field_name = field[1:] if is_desc else field
# special handling for due_date with COALESCE
if field_name == "due_date":
expr = fn.COALESCE(getattr(model_class, field_name), datetime(3000, 12, 31))
sort_expressions.append(expr.desc() if is_desc else expr)
else:
field_expr = getattr(model_class, field_name)
sort_expressions.append(field_expr.desc() if is_desc else field_expr)
return sort_expressions
def get_tasks(
    search_text: str | None = None,
    category: str | None = None,
    statuses: tuple[str, ...] | None = None,
    sort: str = "",
) -> list[Task]:
    """
    Return non-deleted tasks matching every filter that is supplied.

    Args:
        search_text: substring to look for in ``Task.text``, compared on
            lower-cased text.
        category: category *name* to filter on. It is looked up with
            ``Category.get(name=...)``, which raises ``Category.DoesNotExist``
            when no category has that name.
        statuses: statuses to include; falsy means "any status".
        sort: sort spec in the ``'field1,-field2'`` form parsed by
            ``_parse_sort_string``.
    """
    query = Task.select().where(~Task.deleted)
    if search_text:
        query = query.where(fn.Lower(Task.text).contains(search_text.lower()))
    if category:
        query = query.where(Task.category == Category.get(name=category))
    if statuses:
        query = query.where(Task.status.in_(statuses))
    sort_expressions = _parse_sort_string(sort, Task)
    query = query.order_by(*sort_expressions)
    return list(query)
def get_categories() -> list[Category]:
    """Return every category, ordered by its ``name`` column."""
    by_name = Category.select().order_by(Category.name)
    return list(by_name)
def save_view(name: str, *, filters: dict, sort_string: str) -> SavedSearch:
    """Persist a named saved view and return the new ``SavedSearch`` row.

    ``filters`` is stored as a JSON string produced by ``json.dumps``.
    """
    encoded_filters = json.dumps(filters)
    return SavedSearch.create(
        name=name,
        filters=encoded_filters,
        sort_string=sort_string,
    )
def get_saved_view_names() -> list[str]:
    """Return the names of all saved views (no ORDER BY, so order is unspecified)."""
    all_views = SavedSearch.select()
    return [view.name for view in all_views]
def get_saved_view(name: str) -> SavedSearch:
    """Return the saved view called ``name``.

    Raises ``SavedSearch.DoesNotExist`` (from peewee's ``Model.get``) when no
    saved view has that name.
    """
    name_matches = SavedSearch.name == name
    return SavedSearch.get(name_matches)