Source code for streamlitextras.threader

import os
import time
import inspect
import threading
from collections.abc import Callable, Iterable, Mapping

from streamlit.runtime.scriptrunner import (
    add_script_run_ctx,
    get_script_run_ctx,
    RerunException,
    ScriptRunContext,
)
from typing import Any, Callable, Optional

# script_dir = os.path.dirname(os.path.realpath(__file__))
# Directory in which the reruntrigger_<id>.py files are created: the process's
# current working directory, captured once at import time (the commented-out
# line above would instead have used this module's own directory).
script_dir = os.getcwd()

# unique_id used for the trigger file when a caller does not supply one.
default_id = "default"


def trigger_file_path(unique_id: str = default_id) -> str:
    """
    Return the path of the rerun trigger file for ``unique_id``.

    The file lives in ``script_dir`` and is named ``reruntrigger_<unique_id>.py``.
    If it does not exist yet it is created, containing a single
    ``timestamp = <current time>`` assignment.

    :param str unique_id: Identifier selecting which trigger file to use.
    :returns: The trigger file's path.
    """
    path = os.path.join(script_dir, f"reruntrigger_{unique_id}.py")
    if os.path.exists(path):
        return path
    # First use of this ID: create the file so it exists on disk.
    with open(path, "w") as handle:
        handle.write(f"timestamp = {time.time()}")
    return path
# Module-wide lock (shared by every unique_id) that serialises
# trigger_rerun()'s "check last write time, then rewrite the trigger file"
# sequence across threads.
lock = threading.Lock()
def last_trigger_time(unique_id: str = default_id) -> float:
    """
    Returns the seconds since last writing the trigger file

    :param str unique_id: The trigger ID whose file should be inspected.
    :returns: Seconds elapsed since the trigger file for ``unique_id`` was last
        modified, or 9999 if that file does not exist (yet).
    """
    # Build the path directly rather than via trigger_file_path(): that helper
    # creates a missing file and stamps it with the current mtime, which made
    # the "file does not exist" branch unreachable and caused the first
    # trigger_rerun() for a new unique_id to be skipped as "written too
    # recently". Keep this naming scheme in sync with trigger_file_path().
    this_trigger_file_path = os.path.join(script_dir, f"reruntrigger_{unique_id}.py")
    try:
        modified_time = os.path.getmtime(this_trigger_file_path)
    except OSError:
        # Missing/unreadable file: treat it as "written long ago".
        return 9999
    return time.time() - modified_time
def trigger_rerun(
    unique_id: str = default_id, last_write_margin: int = 1, delay: int = 0
) -> None:
    """
    Triggers streamlit to rerun the current page state.
    runOnSave must be set to true in config.toml

    The rerun is requested by rewriting the trigger file for ``unique_id``
    (see trigger_file_path()) with a fresh timestamp.

    :param str unique_id: Unique ID to be triggered, should be set per session e.g. user id or a hash you create in their session state.
    :param int last_write_margin: If the file was modified less than this many seconds ago, the rerun will not be performed (0 disables this check)
    :param int delay: sleep for this many seconds before writing the rerun trigger
    """
    if delay:
        time.sleep(delay)
    with lock:
        modified_time_seconds = last_trigger_time(unique_id)
        if last_write_margin == 0 or modified_time_seconds > last_write_margin:
            # Identify the calling function and its caller for the log line.
            # Either frame may be absent (called from a top-level script, or on
            # interpreters where currentframe() returns None), so fall back to
            # a placeholder instead of raising AttributeError.
            frame = inspect.currentframe()
            try:
                caller_frame = frame.f_back if frame is not None else None
                outer_frame = (
                    caller_frame.f_back if caller_frame is not None else None
                )
            finally:
                # Release our own frame promptly to avoid a frame <-> local
                # reference cycle, as the inspect documentation recommends.
                del frame
            caller = (
                caller_frame.f_code.co_name
                if caller_frame is not None
                else "<unknown>"
            )
            caller_caller = (
                outer_frame.f_code.co_name
                if outer_frame is not None
                else "<unknown>"
            )
            trigger_file = trigger_file_path(unique_id)
            print(
                "Writing trigger",
                trigger_file,
                f"from `{caller_caller}.{caller}`",
                flush=True,
            )
            with open(trigger_file, "w") as f:
                f.write(f"timestamp = {time.time()}")
# References on running background threads with Streamlit and on getting the
# app to update when files change:
# https://github.com/streamlit/streamlit/issues/1792
# https://discuss.streamlit.io/t/using-streamlit-with-multithreading/30990
# https://discuss.streamlit.io/t/how-to-run-a-subprocess-programs-using-thread-inside-streamlit/2440/2
# https://discuss.streamlit.io/t/how-to-monitor-the-filesystem-and-have-streamlit-updated-when-some-files-are-modified/822/3
def thread_wrapper(
    thread_func,
    rerun_st=True,
    last_write_margin: int = 1,
    delay: int = 0,
    trigger_unique_id: str = default_id,
    *args,
    **kwargs,
) -> Any:
    """
    Wrapper for running thread functions

    Calls ``thread_func(*args, **kwargs)`` and then, if ``rerun_st is True``,
    calls trigger_rerun(). If ``thread_func`` raises, no rerun is triggered
    and the exception propagates.

    For parameters see streamlit_thread() and trigger_rerun()

    :returns: Whatever ``thread_func`` returned. PropagatingThread.join()
        returns its target's return value, so this lets callers of join()
        receive the thread function's result.
    """
    result = thread_func(*args, **kwargs)
    if rerun_st is True:
        trigger_rerun(trigger_unique_id, last_write_margin, delay)
    return result
def streamlit_thread(
    thread_func: Callable,
    args: tuple = (),
    kwargs: dict | None = None,
    rerun_st: bool = True,
    last_write_margin: int = 1,
    delay: int = 0,
    script_run_context: ScriptRunContext | None = None,
    autostart: bool = True,
    trigger_unique_id: str = default_id,
    error_handler: Callable | None = None,
) -> str:
    """
    Spawns and starts a threading.Thread that runs thread_func with the passed args and kwargs

    :param Callable thread_func: The function to run in the thread
    :param tuple args: The args to pass to the function in the thread
    :param dict kwargs: The kwargs to pass to the function in the thread (None means no kwargs)
    :param bool rerun_st: Whether to rerun streamlit after the thread function finishes
    :param int last_write_margin: Passed to trigger_rerun(); skip the rerun if the trigger file was written less than this many seconds ago
    :param int delay: Passed to trigger_rerun(); seconds to sleep before writing the trigger
    :param ScriptRunContext script_run_context: Context attached to the thread; defaults to get_script_run_ctx() of the calling thread
    :param bool autostart: Start the thread immediately; otherwise start it later via get_thread(name).start()
    :param str trigger_unique_id: Unique ID whose trigger file is written on rerun
    :param Callable error_handler: Error handler function that takes the thread exception as an argument
    :returns: The name of the thread. Can use get_thread to get the threading.Thread instance
    """
    # None instead of a shared mutable {} default.
    if kwargs is None:
        kwargs = {}
    # thread_wrapper's leading positional parameters come first, followed by
    # the caller's own positional args for thread_func.
    wrapper_args = (
        thread_func,
        rerun_st,
        last_write_margin,
        delay,
        trigger_unique_id,
        *args,
    )
    thread = PropagatingThread(
        target=thread_wrapper,
        error_handler=error_handler,
        args=wrapper_args,
        kwargs=kwargs,
    )
    if script_run_context is None:
        script_run_context = get_script_run_ctx()
    add_script_run_ctx(thread, script_run_context)
    # NOTE(review): unconditional 0.4 s pause on the calling thread (even when
    # autostart is False); its purpose is not evident from this module —
    # confirm before removing.
    time.sleep(0.4)
    if autostart is True:
        thread.start()
    return thread.name
def get_thread(thread_name) -> Optional[threading.Thread]:
    """
    Look up a live thread by its ``name`` attribute.

    :param thread_name: The name to search for among threading.enumerate().
    :returns: The first matching threading.Thread, or None if no live thread
        has that name.
    """
    matches = (t for t in threading.enumerate() if t.name == thread_name)
    return next(matches, None)
class PropagatingThread(threading.Thread):
    """
    threading.Thread subclass that records its target's return value or
    exception, so join() can return the value or re-raise the exception in
    the joining thread.

    Accepts an optional ``error_handler`` keyword argument (removed before the
    remaining arguments are passed to threading.Thread). When it is callable,
    exceptions are passed to it instead of being re-raised.
    """

    def __init__(self, *args, **kwargs) -> None:
        # pop() with a default: the previous get() + del raised KeyError
        # whenever error_handler was not supplied.
        self.error_handler = kwargs.pop("error_handler", None)
        super().__init__(*args, **kwargs)
        # Initialise results up front so join() never hits AttributeError,
        # e.g. when its timeout expires before run() finishes, or when the
        # exception was passed to error_handler and no return value exists.
        self.exc = None
        self.ret = None

    def run(self):
        self.exc = None
        try:
            # Mirror threading.Thread.run(), which tolerates target=None.
            if self._target is not None:
                self.ret = self._target(*self._args, **self._kwargs)
        except RerunException as e:
            # Stored without re-raising here; join() re-raises it (or passes
            # it to error_handler) in the joining thread.
            self.exc = e
        except BaseException as e:
            self.exc = e
            if self.error_handler and callable(self.error_handler):
                self.error_handler(e)
            else:
                raise

    def join(self, timeout=None):
        super().join(timeout)
        if self.exc is not None:
            if self.error_handler and callable(self.error_handler):
                # NOTE(review): for non-RerunException errors run() already
                # called error_handler in the worker thread, so it runs a
                # second time here — confirm this is intended.
                self.error_handler(self.exc)
            else:
                raise self.exc
                # raise RuntimeError('Exception in thread') from self.exc
        return self.ret