import { fetchEventSource } from "@microsoft/fetch-event-source";
import { baseURL, resolveFetch } from "@/api/shared";
import { defaultRetryFunc } from "@/hooks/shared";
import { useAuth } from "@/hooks/shared";
import { QueryConfig } from "@/lib/reactQuery";
import { User } from "@/types";
import { assertUserIsAuthenticated } from "@/utils";
import { useQuery } from "@tanstack/react-query";
import {
  Thread,
  ThreadStreamMessageEvent,
  ThreadState,
  ThreadContext,
  ThreadConfig,
  ThreadStatus,
  ToolCall,
} from "../types/api";
import { useAssistantStore } from "../store/useAssistantStore";
import { useEffect, useState, useRef } from "react";
import { applyPatch } from "fast-json-patch";

/**
 * Loads a single assistant thread from the API.
 *
 * @param user - Authenticated user; its ID token is sent as a Bearer token.
 * @param orgId - Organisation that owns the thread.
 * @param threadId - Thread to load.
 * @param mode - "org" reads through the organisation-scoped endpoint,
 *   "user" (default) through the current user's endpoint.
 * @returns The value `resolveFetch` produces for the response, typed as a `Thread`.
 */
export const getThread = async (
  user: User,
  orgId: string,
  threadId: string,
  mode: "org" | "user" = "user",
): Promise<Thread> => {
  // Only the user-scoped endpoint has the extra "/users/current" segment.
  const scopeSegment = mode === "org" ? "" : "/users/current";
  const endpoint = `${baseURL}/orgs/${orgId}${scopeSegment}/assistant-threads/${threadId}`;

  const idToken = await user.getIdToken();
  const request = fetch(endpoint, {
    headers: {
      Authorization: `Bearer ${idToken}`,
    },
  });

  return await resolveFetch(request);
};

/** Options accepted by `useThread`. */
type useThreadOptions = {
  /** Organisation that owns the thread. */
  orgId: string;
  /** Thread to load; the query stays disabled while this is undefined. */
  threadId?: string;
  /** Extra react-query options, spread last so they override the defaults. */
  config?: QueryConfig<typeof getThread>;
  /**
   * Endpoint to read through: "org" (organisation-scoped) or "user"
   * (current user). Optional — `useThread` falls back to "user".
   */
  mode?: "org" | "user";
};

/**
 * React-query hook that loads one assistant thread.
 *
 * The query is keyed by org, thread id and mode, only runs while a thread id
 * is present, and does not refetch on an interval or on window focus. Any
 * `config` supplied by the caller is spread last and overrides these
 * defaults.
 */
export const useThread = ({
  orgId,
  threadId,
  config,
  mode = "user",
}: useThreadOptions) => {
  const { user } = useAuth();
  assertUserIsAuthenticated(user);

  // Without a thread id there is nothing to fetch; resolve to null.
  const fetchThread = () =>
    threadId ? getThread(user, orgId, threadId, mode) : null;

  return useQuery({
    queryKey: [orgId, "assistant-threads", threadId, mode],
    queryFn: fetchThread,
    enabled: Boolean(user) && Boolean(threadId),
    retry: defaultRetryFunc,
    refetchInterval: false,
    refetchOnWindowFocus: false,
    ...config,
  });
};

/**
 * Opens a server-sent-events subscription to an assistant thread.
 *
 * The thread is first fetched once (via `getThread`) to seed the status and
 * the state, which are reported immediately. Incoming events are then applied
 * on top:
 * - `thread_status` events are forwarded to `onStatusUpdate`;
 * - `step_start` events resynchronise the event-id counter and surface the
 *   first tool call of the latest message through `setCurrentToolCall`;
 * - `state_patch` events are applied as a JSON patch to the local state,
 *   which is then passed to `onStateUpdate`.
 *
 * `state_patch` / `step_end` ids are expected to grow by one. A gap means an
 * event was missed and the local state can no longer be trusted.
 *
 * `onError` is the single failure signal. It fires at most once per call,
 * after the connection has been aborted, when any of these happens:
 * - the stream fails to open, errors, or is closed by the server;
 * - an event is skipped;
 * - an event cannot be parsed or applied.
 *
 * The caller is expected to reconnect by calling `streamThread` again.
 * fetch-event-source's own automatic retry is disabled so reconnection is not
 * duplicated. `onError` is not called after the returned closer has run.
 *
 * @returns A function that aborts the subscription.
 * @throws If the ID token or the initial `getThread` request fails.
 */
export const streamThread = async (
  user: User,
  orgId: string,
  threadId: string,
  superAdminReviewModeEnabled: boolean,
  onError: () => void,
  onStateUpdate: (state: ThreadState) => void,
  onStatusUpdate: (status: ThreadStatus) => void,
  setCurrentToolCall: (toolCall: ToolCall | null) => void,
): Promise<() => void> => {
  const token = await user.getIdToken();
  const ctrl = new AbortController();

  const thread = await getThread(
    user,
    orgId,
    threadId,
    superAdminReviewModeEnabled ? "org" : "user",
  );

  onStatusUpdate(thread.status);

  let currentState: ThreadState = thread.state ?? {
    context: {} as ThreadContext,
    messages: {},
    message_order: [],
    source_nodes: {},
    title_parts: [],
    config: {} as ThreadConfig,
  };

  onStateUpdate(currentState);

  // NOTE(review): the initial fetch above honours review mode ("org"), but the
  // events endpoint is always the current-user one — confirm this is intended.
  const url = `${baseURL}/orgs/${orgId}/users/current/assistant-threads/${threadId}/events`;

  // Id of the last accepted event, used to detect gaps in the sequence.
  let prevMessageId = 0;
  // Set once the subscription has failed or was closed by the caller.
  let stopped = false;

  // Abort the connection and report the failure — at most once.
  const fail = () => {
    if (stopped) return;
    stopped = true;
    ctrl.abort();
    onError();
  };

  fetchEventSource(url, {
    method: "GET",
    headers: {
      Authorization: `Bearer ${token}`,
    },
    credentials: "include",
    openWhenHidden: true,
    signal: ctrl.signal,
    async onmessage(event) {
      // Lines from an already-received chunk can still arrive after abort.
      if (stopped) return;
      try {
        if (!event.data || event.data.trim() === "") {
          return;
        }

        const message = JSON.parse(event.data) as ThreadStreamMessageEvent;

        switch (message.type) {
          case "thread_status":
            onStatusUpdate(message.status);
            break;

          case "step_start": {
            // A step start resynchronises the expected event sequence.
            prevMessageId = message.id;

            const order = message.state.message_order;
            // Guard against an empty message list instead of crashing.
            const lastMessage =
              order.length > 0
                ? message.state.messages[order[order.length - 1]]
                : undefined;
            if (lastMessage?.tool_calls && lastMessage.tool_calls.length > 0) {
              setCurrentToolCall(lastMessage.tool_calls[0]);
            }
            break;
          }

          case "state_patch":
          case "step_end": {
            if (message.id > prevMessageId + 1) {
              // An event was skipped: stop here rather than apply a patch
              // to a state that is already out of sync.
              fail();
              return;
            }
            prevMessageId = message.id;

            if (message.type === "state_patch" && message.state_patch) {
              const patch = message.state_patch;

              // New content arriving means the pending tool call finished.
              if (patch.op === "add" && patch.path.includes("/content_parts")) {
                setCurrentToolCall(null);
              }

              currentState = applyPatch(
                currentState,
                [patch],
                true,
                false,
              ).newDocument;

              onStateUpdate(currentState);
            }
            break;
          }
        }
      } catch {
        // Unparseable event, failed patch, or a throwing callback.
        fail();
      }
    },
    async onopen(response) {
      if (!response.ok || response.status !== 200) {
        // Throwing routes the failure through `onerror` and stops the library
        // from parsing the error body as an event stream.
        throw new Error(
          `Thread event stream failed to open: HTTP ${response.status}`,
        );
      }
    },
    onclose() {
      // The server ended the stream; report it so the caller can reconnect.
      fail();
    },
    onerror(err) {
      fail();
      // Rethrow so fetch-event-source does not retry on its own; the caller
      // owns reconnection.
      throw err;
    },
  }).catch(() => {
    // Rejection only happens when `onerror` rethrows; the failure has already
    // been reported through `onError`.
  });

  return () => {
    stopped = true;
    ctrl.abort();
  };
};

/**
 * Subscribes to an assistant thread's live event stream and exposes its
 * accumulated state plus the tool call currently in progress.
 *
 * The initial state comes from `useThread` (react-query) and is kept up to
 * date by `streamThread`. Whenever the stream reports an error, the
 * connection is closed, the store status is set to "processing", and a new
 * connection is opened after a delay of 500 ms × attempt number.
 *
 * Status updates are pushed to the assistant store, which is reset to
 * "ready" on unmount. Switching threads clears the previous thread's state.
 *
 * @returns `threadState` (null until something has loaded) and
 *   `currentToolCall` (null when no tool call is active).
 */
export const useThreadStream = ({
  orgId,
  threadId,
  superAdminReviewModeEnabled,
}: {
  orgId: string;
  threadId?: string;
  superAdminReviewModeEnabled: boolean;
}) => {
  const { user } = useAuth();
  assertUserIsAuthenticated(user);
  const { setThreadStatus } = useAssistantStore();

  const { data: initialThreadData } = useThread({
    orgId,
    threadId,
    mode: superAdminReviewModeEnabled ? "org" : "user",
  });

  const initialState = initialThreadData?.state ?? null;

  const [threadState, setThreadState] = useState<ThreadState | null>(null);
  const [currentToolCall, setCurrentToolCall] = useState<ToolCall | null>(null);

  // Latest authenticated user. Reconnects read it from here so they get a
  // fresh ID token, while a new user object identity alone does not tear
  // down a healthy stream.
  const userRef = useRef(user);
  useEffect(() => {
    userRef.current = user;
  }, [user]);

  // Drop the previous thread's data when switching threads; otherwise the
  // merge below would mix its messages into the new thread.
  useEffect(() => {
    setThreadState(null);
    setCurrentToolCall(null);
  }, [threadId]);

  // Main stream setup effect. Each run owns exactly one subscription (plus at
  // most one pending retry). Its cleanup disposes both, and late async results
  // from a disposed run are discarded.
  useEffect(() => {
    if (!threadId) return;
    const activeThreadId = threadId;

    let disposed = false;
    let closeConnection: (() => void) | null = null;
    let retryTimer: ReturnType<typeof setTimeout> | null = null;
    // True while a connection attempt is in flight or a retry is scheduled;
    // further error reports in that window are ignored.
    let reconnecting = false;

    const closeCurrentConnection = () => {
      if (closeConnection) {
        closeConnection();
        closeConnection = null;
      }
    };

    // Merge a streamed snapshot into the displayed state. Messages, message
    // order and source nodes from both are kept; newer values win.
    const mergeIntoState = (newState: ThreadState) => {
      setThreadState((prevState) => {
        if (!prevState) return newState;
        return {
          ...prevState,
          ...newState,
          messages: {
            ...prevState.messages,
            ...newState.messages,
          },
          message_order: Array.from(
            new Set([...prevState.message_order, ...newState.message_order]),
          ),
          source_nodes: {
            ...prevState.source_nodes,
            ...newState.source_nodes,
          },
        };
      });
    };

    // NOTE(review): retries are unbounded and the attempt counter is never
    // reset, so the delay keeps growing by 500 ms per reconnect. The former
    // `maxRetries = 30` argument was never enforced either.
    const scheduleReconnect = (attempt: number) => {
      if (disposed || reconnecting) return;
      reconnecting = true;
      closeCurrentConnection();
      setThreadStatus("processing");
      retryTimer = setTimeout(() => {
        retryTimer = null;
        void connect(attempt + 1);
      }, 500 * attempt);
    };

    const connect = async (attempt: number): Promise<void> => {
      reconnecting = true;
      try {
        const close = await streamThread(
          userRef.current,
          orgId,
          activeThreadId,
          superAdminReviewModeEnabled,
          () => scheduleReconnect(attempt),
          (newState) => {
            if (!disposed) mergeIntoState(newState);
          },
          (status) => {
            if (!disposed) setThreadStatus(status);
          },
          (toolCall) => {
            if (!disposed) setCurrentToolCall(toolCall);
          },
        );
        if (disposed) {
          // The thread/org/mode changed or we unmounted while connecting.
          close();
          return;
        }
        closeConnection = close;
        reconnecting = false;
      } catch {
        // The token or initial thread fetch failed before the stream opened.
        reconnecting = false;
        scheduleReconnect(attempt);
      }
    };

    void connect(0);

    return () => {
      disposed = true;
      if (retryTimer !== null) {
        clearTimeout(retryTimer);
        retryTimer = null;
      }
      closeCurrentConnection();
    };
    // `setThreadStatus` is a store action and `user` is read via `userRef`;
    // neither should force a reconnect.
  }, [orgId, threadId, superAdminReviewModeEnabled]);

  useEffect(() => {
    if (initialState) {
      setThreadState(initialState);
    }
  }, [initialState]);

  // Reset the shared status when the consumer unmounts.
  useEffect(() => {
    return () => {
      setThreadStatus("ready");
    };
  }, []);

  return { threadState, currentToolCall };
};
