search funciton

This commit is contained in:
olari
2021-07-05 14:48:00 +03:00
parent 1e352c7b06
commit 22d1150801

View File

@@ -5,7 +5,6 @@ from pathlib import Path
from shutil import copyfile, rmtree from shutil import copyfile, rmtree
from subprocess import run from subprocess import run
from tempfile import mktemp, mkdtemp from tempfile import mktemp, mkdtemp
from typing import Any, Callable, Optional, Tuple, TypeVar, Union
from zipfile import ZipFile from zipfile import ZipFile
import json import json
import random import random
@@ -20,6 +19,11 @@ JOURNAL_PATH = Path.home() / '.journal.json'
### UTILS ### UTILS
def remove_chars(text, chars):
return ''.join([c for c in text if c not in chars])
def get_words(text):
return remove_chars(text, '.,-:;').lower().split()
def nth_or_default(n, l, default): def nth_or_default(n, l, default):
return l[n] if n < len(l) else default return l[n] if n < len(l) else default
@@ -646,7 +650,7 @@ def parse_entry(timestamp, content):
} }
def generate_entry(entry): def generate_entry(entry):
def format_block(curr: Any, prev: Any, before_prev: Any): def format_block(curr, prev, before_prev):
def format_text(text): def format_text(text):
if all(c == '\n' for c in curr): if all(c == '\n' for c in curr):
return text return text
@@ -697,12 +701,12 @@ def generate_entry(entry):
return result return result
ENTRY_RE = re.compile(r'^(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}) ?', re.MULTILINE)
def parse_day(text): def parse_day(text):
# discard read-only QS section # discard read-only QS section
text = text[text.find('#'):] text = text[text.find('#'):]
ENTRY_RE = re.compile(r'^(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}) ?', re.MULTILINE)
header, *tmp = ENTRY_RE.split(text) header, *tmp = ENTRY_RE.split(text)
entries = list(zip(tmp[::2], tmp[1::2])) entries = list(zip(tmp[::2], tmp[1::2]))
@@ -950,6 +954,57 @@ def handle_backup(args):
if prompt('Delete backup archive?'): if prompt('Delete backup archive?'):
archive_path.unlink() archive_path.unlink()
def handle_search(args):
query = args[0]
parts = query.split(',')
strings = []
tags = []
for part in parts:
if part.startswith('#'):
tags.append(part.removeprefix('#'))
else:
strings.append(part)
journal = load_journal()
matches = {}
for day in journal['days']:
for i, entry in enumerate(journal['days'][day]['entries']):
for block in entry['blocks']:
if isinstance(block, str):
words = get_words(block)
if any(s in words for s in strings):
matches[entry['timestamp']] = (day, i)
break
elif block['type'] == 'tag':
if any(t in block['value'] for t in tags):
matches[entry['timestamp']] = (day, i)
break
result = ''
result += f'Num matches: {len(matches)}\n'
result += '---\n'
for day, idx in matches.values():
entry = journal['days'][day]['entries'][idx]
result += generate_entry(entry)
text = edit_text(result)
_, *tmp = ENTRY_RE.split(text)
entries = [parse_entry(ts, c) for ts, c in list(zip(tmp[::2], tmp[1::2]))]
for entry in entries:
day, idx = matches[entry['timestamp']]
journal['days'][day]['entries'][idx] = entry
save_journal(journal)
### MAIN ### MAIN
def main(): def main():
@@ -969,6 +1024,7 @@ def main():
'test': handle_test, 'test': handle_test,
'summary': handle_summary, 'summary': handle_summary,
'backup': handle_backup, 'backup': handle_backup,
'search': handle_search,
} }
handler = command_handlers.get(command, handle_invalid) handler = command_handlers.get(command, handle_invalid)