/*
** Job Arranger for ZABBIX
** Copyright (C) 2025 Daiwa Institute of Research Ltd. All Rights Reserved.
**
** This program is free software; you can redistribute it and/or modify
** it under the terms of the GNU General Public License as published by
** the Free Software Foundation; either version 2 of the License, or
** (at your option) any later version.
**
** This program is distributed in the hope that it will be useful,
** but WITHOUT ANY WARRANTY; without even the implied warranty of
** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
** GNU General Public License for more details.
**
** You should have received a copy of the GNU General Public License
** along with this program; if not, write to the Free Software
** Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
**/

package main

/*
#cgo CFLAGS: -I/usr/include
#cgo LDFLAGS: -lvterm
#include <vterm.h>
#include "vterm_helper.h"
#include <stdlib.h>
#include <locale.h>
#include <stdio.h>
#include <string.h>
*/
import "C"

import (
	"bufio"
	"bytes"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io"
	clientcommon "jobarranger2/src/jobarg_server/managers/icon_exec_manager/workers/common"
	"jobarranger2/src/libs/golibs/common"
	"log"
	"net"
	"os"
	"path/filepath"
	"runtime/debug"
	"strconv"
	"strings"
	"sync"
	"time"
	"unsafe"

	"golang.org/x/crypto/ssh"
)

// Package-level state of the ssh client process.
//
// oldReq holds the last request that did not continue a session; its RunMode
// and LineFeedCode are reused by continuation requests in handleConnection.
// The connection settings (hostIP, hostPort, loginUser, ...) are filled from
// the command-line arguments in main. Several of the remaining variables
// (e.g. operationFlag, hostFlag, stopCode, terminalType, dataFilePath) are not
// referenced in this file.
// NOTE(review): confirm they are unused elsewhere before removing them.
var (
	oldReq, newReq   clientcommon.SshExecRequest
	operationFlag    int
	hostFlag         int
	connectionMethod int
	sessionFlag      int
	runMode          int
	lineFeedCode     int
	timeout          int

	sessionId      string
	loginUser      string
	loginPassword  string
	publicKey      string
	privateKey     string
	passPhrase     string
	hostIP         string
	hostPort       string
	stopCode       string
	terminalType   string
	characterCode  string
	promptString   string
	command        string
	socketFilename string
	socketFilePath string
	tcpTimeout     string
	dataFilePath   string

	// closeChan tells main to shut down when a "close" request arrives.
	// It has a buffer of one because handleConnection sends with a
	// non-blocking select. On an unbuffered channel that send would be
	// silently dropped whenever main is not already blocked in its receive,
	// and the process would then never exit.
	closeChan = make(chan bool, 1)

	// authMethod is the SSH authentication method chosen in main
	// (password or public key).
	authMethod ssh.AuthMethod
)

// SSHConnection wraps an established SSH client.
type SSHConnection struct {
	Client *ssh.Client
}

// sshConnections maps a "host:port" key to its cached connection; it is
// looked up and pruned by SSHDisconnect while holding mu.
var sshConnections = map[string]*SSHConnection{}

// mu guards sshConnections.
var mu sync.Mutex

// processCells turns a slice of libvterm screen cells into text.
//
// Only chars[0] of each cell is read. A non-zero value is written as a rune.
// A zero value is written as a newline, and a second zero in a row ends the
// scan without emitting another newline.
//
// NOTE(review): where the zero cells appear is decided by fill_vterm_cells
// (vterm_helper.h), which is not visible here. Also, libvterm typically
// reports the right half of a double-width character with
// chars[0] == 0xFFFFFFFF. Unless the helper filters such cells, that value
// ends up here as U+FFFD — confirm.
func processCells(cells []C.VTermScreenCell) string {
	var sb strings.Builder
	zeroRun := 0

	for idx := range cells {
		first := cells[idx].chars[0]
		if first != 0 {
			zeroRun = 0
			sb.WriteRune(rune(first))
			continue
		}

		zeroRun++
		if zeroRun >= 2 {
			break
		}
		sb.WriteByte('\n')
	}

	return sb.String()
}

// SSHDisconnect closes the cached SSH connection for host:port and removes it
// from sshConnections.
//
// It returns an error when no connection is cached under that key, or when
// closing the client fails. In the latter case the entry is still removed:
// a client is not reusable after Close, so keeping it would only make every
// later call retry Close on a dead connection.
//
// NOTE(review): the key is built as "host:port" without brackets, so IPv6
// literals are not bracketed the way net.JoinHostPort (used in main) would
// bracket them. Nothing in this file inserts into sshConnections; keep the
// key format in sync with whatever code does.
func SSHDisconnect(host, port string) error {
	mu.Lock()
	defer mu.Unlock()

	addr := fmt.Sprintf("%s:%s", host, port)
	conn, exists := sshConnections[addr]
	if !exists {
		return fmt.Errorf("no active SSH connection to %s", addr)
	}

	// Forget the connection before closing so a failed Close cannot leave a
	// dead client cached.
	delete(sshConnections, addr)

	if err := conn.Client.Close(); err != nil {
		return fmt.Errorf("failed to close connection to %s: %w", addr, err)
	}

	log.Printf("Disconnected from %s\n", addr)
	return nil
}

// terminalEnvOnce makes the process-wide locale/TERM setup performed before
// libvterm rendering run only once. setlocale and setenv modify global
// process state, and runInteractiveCommand can run on several goroutines at
// the same time (main serves every socket connection on its own goroutine).
var terminalEnvOnce sync.Once

// applyTerminalEnv sets LC_CTYPE to "C.utf8" and TERM to "xterm" for the
// whole process. The temporary C strings are freed afterwards. setenv copies
// its arguments (POSIX), and glibc's setlocale keeps its own copy of the
// locale name.
func applyTerminalEnv() {
	locale := C.CString("C.utf8")
	defer C.free(unsafe.Pointer(locale))
	C.setlocale(C.LC_CTYPE, locale)

	name := C.CString("TERM")
	defer C.free(unsafe.Pointer(name))
	value := C.CString("xterm")
	defer C.free(unsafe.Pointer(value))
	C.setenv(name, value, 1)
}

// renderVTermScreen feeds raw into a freshly reset rows x cols libvterm
// instance. It returns the resulting screen text (via fill_vterm_cells and
// processCells) with surrounding whitespace trimmed. The vterm instance is
// freed before returning, so repeated calls do not accumulate C memory.
func renderVTermScreen(raw []byte, rows, cols int) string {
	terminalEnvOnce.Do(applyTerminalEnv)

	vt := C.vterm_new(C.int(rows), C.int(cols))
	defer C.vterm_free(vt)
	C.vterm_set_utf8(vt, 1)
	screen := C.vterm_obtain_screen(vt)
	C.vterm_screen_reset(screen, 1)

	if len(raw) > 0 {
		C.vterm_input_write(vt, (*C.char)(unsafe.Pointer(&raw[0])), C.size_t(len(raw)))
	}
	C.vterm_screen_flush_damage(screen)

	// NOTE(review): 5x the visible screen is passed to fill_vterm_cells;
	// the reason is not visible here (vterm_helper.h) — confirm.
	totalCells := rows * cols * 5
	cells := make([]C.VTermScreenCell, totalCells)
	C.fill_vterm_cells(vt, (*C.VTermScreenCell)(unsafe.Pointer(&cells[0])), C.int(totalCells))

	return strings.TrimSpace(processCells(cells))
}

// runInteractiveCommand opens a shell session with a 200x80 "xterm" PTY on
// client and sends command after line-feed conversion. It then reads the
// combined output until the rendered screen text ends with promptString, or
// until the stream ends.
//
// After each read, all output received so far is re-rendered through
// libvterm. This resolves echo, cursor movement and escape sequences into
// plain screen text. The last rendered text is returned.
//
// Line feeds in command are normalized for lineFeedCode:
//   - LF:   CRLF -> LF
//   - CR:   CRLF -> CR
//   - CRLF: any LF or CRLF -> CRLF
//
// An unknown lineFeedCode, or a failure to set up the session or send the
// command, is returned as an error. A read error (including EOF) ends the
// loop; the text rendered so far is returned with a nil error.
//
// NOTE(review): there is no overall timeout. If promptString never appears
// and the remote end never closes the stream, this call blocks indefinitely.
func runInteractiveCommand(client *ssh.Client, command string, promptString string, lineFeedCode int) (string, error) {
	session, err := client.NewSession()
	if err != nil {
		return "", fmt.Errorf("failed to create SSH session: %v", err)
	}
	defer session.Close()

	modes := ssh.TerminalModes{
		ssh.ECHO:          1,
		ssh.TTY_OP_ISPEED: 14400,
		ssh.TTY_OP_OSPEED: 14400,
	}

	cols, rows := 200, 80
	if err := session.RequestPty("xterm", rows, cols, modes); err != nil {
		return "", fmt.Errorf("request for pseudo terminal failed: %v", err)
	}

	stdin, err := session.StdinPipe()
	if err != nil {
		return "", fmt.Errorf("failed to get stdin pipe: %v", err)
	}
	stdout, err := session.StdoutPipe()
	if err != nil {
		return "", fmt.Errorf("failed to get stdout pipe: %v", err)
	}
	stderr, err := session.StderrPipe()
	if err != nil {
		return "", fmt.Errorf("failed to get stderr pipe: %v", err)
	}

	if err := session.Shell(); err != nil {
		return "", fmt.Errorf("failed to start shell: %v", err)
	}

	reader := bufio.NewReader(io.MultiReader(stdout, stderr))

	// line feed conversion
	switch lineFeedCode {
	case common.JA_LINE_FEED_CODE_LF:
		// linux
		command = strings.ReplaceAll(command, "\r\n", "\n")
	case common.JA_LINE_FEED_CODE_CR:
		// mac
		command = strings.ReplaceAll(command, "\r\n", "\r")
	case common.JA_LINE_FEED_CODE_CRLF:
		// windows: normalize to LF first so existing CRLFs are not doubled
		command = strings.ReplaceAll(command, "\r\n", "\n")
		command = strings.ReplaceAll(command, "\n", "\r\n")
	default:
		return "", fmt.Errorf("unknown line feed code : %d", lineFeedCode)
	}

	// Send the command
	if _, err := io.WriteString(stdin, command); err != nil {
		return "", fmt.Errorf("failed to send command: %v", err)
	}

	var (
		outputBuf   bytes.Buffer
		cleanOutput string
	)
	buf := make([]byte, 2048)

	for {
		n, readErr := reader.Read(buf)

		// Per the io.Reader contract, consume any bytes before looking at
		// the error.
		if n > 0 {
			outputBuf.Write(buf[:n])
			cleanOutput = renderVTermScreen(outputBuf.Bytes(), rows, cols)
			if strings.HasSuffix(cleanOutput, promptString) {
				break
			}
		}

		if readErr == io.EOF {
			fmt.Println("EOF reached")
			break
		} else if readErr != nil {
			fmt.Println("Read error:", readErr)
			break
		}

		// Prevent tight loop
		time.Sleep(50 * time.Millisecond)
	}

	return cleanOutput, nil
}

// runNonInteractiveCommand executes command once, without a PTY, on a new
// session of client. It returns what the command wrote to stdout and stderr
// together with its exit status.
//
// Line feeds in command are converted according to lineFeedCode:
//   - LF:   CRLF -> LF
//   - CR:   CRLF -> CR
//   - CRLF: left unchanged
//
// An unrecognized code is an error. A remote command that exits non-zero
// (*ssh.ExitError) is reported through exitCode with a nil error. Every other
// failure yields exitCode -1 and a non-nil error, and any output captured
// before a Run failure is still returned.
func runNonInteractiveCommand(client *ssh.Client, command string, lineFeedCode int) (stdout string, stderr string, exitCode int, err error) {
	session, sessErr := client.NewSession()
	if sessErr != nil {
		return "", "", -1, fmt.Errorf("failed to create session: %w", sessErr)
	}
	defer session.Close()

	outBuf := new(bytes.Buffer)
	errBuf := new(bytes.Buffer)
	session.Stdout = outBuf
	session.Stderr = errBuf

	converted := command
	switch lineFeedCode {
	case common.JA_LINE_FEED_CODE_LF:
		// linux
		converted = strings.ReplaceAll(command, "\r\n", "\n")
	case common.JA_LINE_FEED_CODE_CR:
		// mac
		converted = strings.ReplaceAll(command, "\r\n", "\r")
	case common.JA_LINE_FEED_CODE_CRLF:
		// windows: \r\n is kept as is
	default:
		return "", "", -1, fmt.Errorf("unknown line feed code : %d", lineFeedCode)
	}

	runErr := session.Run(converted)
	stdout, stderr = outBuf.String(), errBuf.String()

	if runErr == nil {
		return stdout, stderr, 0, nil
	}
	if exitErr, isExit := runErr.(*ssh.ExitError); isExit {
		return stdout, stderr, exitErr.ExitStatus(), nil
	}
	return stdout, stderr, -1, fmt.Errorf("ssh command error: %w", runErr)
}

// sessionStateMu serializes access to oldReq. main starts handleConnection in
// its own goroutine for every accepted connection, so concurrent requests
// would otherwise race on it.
var sessionStateMu sync.Mutex

// handleConnection serves one request/response exchange on conn and then
// closes it.
//
// The wire format is the same in both directions: a big-endian uint32 byte
// length, followed by that many bytes of JSON. The request decodes into a
// clientcommon.SshExecRequest; the reply is a clientcommon.SshExecResponse.
//
// A request whose SessionFlag is JA_SES_OPERATION_FLAG_CONTINUE inherits
// RunMode and LineFeedCode from the last non-continuation request. Any other
// request becomes that reference. The command "close" signals main via
// closeChan and gets no reply. Any other command is run on client,
// interactively or not, according to RunMode.
//
// When the command cannot be run, the error is logged to stderr and the
// reply carries ExitCode -1, so the caller does not read the failure as a
// successful run with exit code 0.
func handleConnection(conn net.Conn, client *ssh.Client) {
	defer conn.Close()

	// Read length prefix (uint32)
	var length uint32
	if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
		if err != io.EOF {
			fmt.Fprintf(os.Stderr, "[SshClient] Failed to read message length: %s\n", err)
		}
		return
	}

	// Read JSON data
	data := make([]byte, length)
	if _, err := io.ReadFull(conn, data); err != nil {
		fmt.Fprintf(os.Stderr, "[SshClient] Failed to read JSON data: %s\n", err)
		return
	}

	// Decode into a fresh value. Decoding into a shared, previously used
	// struct would keep any field this request omits from an earlier request.
	var req clientcommon.SshExecRequest
	if err := json.Unmarshal(data, &req); err != nil {
		fmt.Fprintf(os.Stderr, "[SshClient] Failed to unmarshal SSH request: %s\n", err)
		return
	}

	sessionStateMu.Lock()
	if req.SessionFlag == common.JA_SES_OPERATION_FLAG_CONTINUE {
		req.RunMode = oldReq.RunMode
		req.LineFeedCode = oldReq.LineFeedCode
	} else {
		oldReq = req
	}
	sessionStateMu.Unlock()

	if req.Command == "close" {
		// Non-blocking: never let a close request stall this goroutine.
		select {
		case closeChan <- true:
		default:
		}
		return
	}

	var sshResp clientcommon.SshExecResponse
	if req.RunMode == common.JA_RUN_MODE_INTERACTIVE {
		stdout, err := runInteractiveCommand(client, req.Command, req.PromptString, req.LineFeedCode)
		if err != nil {
			fmt.Fprintf(os.Stderr, "[SshClient] Interactive command execution failed. error : %s\n", err)
			sshResp.ExitCode = -1
		} else {
			sshResp.Stdout = stdout
		}
	} else {
		stdout, stderr, exitCode, err := runNonInteractiveCommand(client, req.Command, req.LineFeedCode)
		if err != nil {
			fmt.Fprintf(os.Stderr, "[SshClient] Non-Interactive command execution failed. error : %s\n", err)
		}
		// On error runNonInteractiveCommand reports exitCode -1 together
		// with any output captured before the failure.
		sshResp.Stdout = stdout
		sshResp.Stderr = stderr
		sshResp.ExitCode = exitCode
	}

	res, err := json.Marshal(sshResp)
	if err != nil {
		fmt.Fprintf(os.Stderr, "[SshClient] Failed to marshal SSH response. error : %s\n", err)
		return
	}

	// Send length-prefixed JSON response
	if err := binary.Write(conn, binary.BigEndian, uint32(len(res))); err != nil {
		fmt.Fprintf(os.Stderr, "[SshClient] Failed to write response length: %s\n", err)
		return
	}

	// Send result data
	if _, err := conn.Write(res); err != nil {
		fmt.Fprintf(os.Stderr, "[SshClient] Failed to write response data: %s\n", err)
		return
	}
}

// loadPrivateKey reads the private key file at privateKeyPath and parses it
// into an ssh.Signer. A non-empty passPhrase is used to decrypt the key; with
// an empty passPhrase the key is parsed as unencrypted. Read and parse errors
// are returned unchanged.
func loadPrivateKey(privateKeyPath, passPhrase string) (ssh.Signer, error) {
	keyBytes, readErr := os.ReadFile(privateKeyPath)
	switch {
	case readErr != nil:
		return nil, readErr
	case passPhrase != "":
		return ssh.ParsePrivateKeyWithPassphrase(keyBytes, []byte(passPhrase))
	default:
		return ssh.ParsePrivateKey(keyBytes)
	}
}

// writeSshClientStatus stores status, as a decimal integer followed by a
// newline, in statusFilePath, replacing any previous content.
//
// The value is first written to statusFilePath+".tmp" and then renamed over
// the destination. On POSIX filesystems the rename is atomic, so a reader
// never sees a half-written status file. If the rename fails, the temporary
// file is removed so that no stale ".tmp" file is left next to the status
// file.
func writeSshClientStatus(statusFilePath string, status int) error {
	tmpPath := statusFilePath + ".tmp"
	data := []byte(strconv.Itoa(status) + "\n")

	if err := os.WriteFile(tmpPath, data, 0644); err != nil {
		return fmt.Errorf("failed to write temp status file: %w", err)
	}

	if err := os.Rename(tmpPath, statusFilePath); err != nil {
		// Best-effort cleanup of the file created just above; the rename
		// error is what the caller needs to see.
		_ = os.Remove(tmpPath)
		return fmt.Errorf("failed to rename status file: %w", err)
	}

	return nil
}

// exitWithConnectFail prints msg to stderr, records common.SSHD_CONNECT_FAIL
// in the status file at statusFilePath, and terminates the process with exit
// status 1. A failure to write the status file is reported on stderr as
// well. Because os.Exit is used, the caller's deferred functions do not run,
// so the fail status file is not removed.
func exitWithConnectFail(statusFilePath, msg string) {
	fmt.Fprint(os.Stderr, msg)
	if err := writeSshClientStatus(statusFilePath, common.SSHD_CONNECT_FAIL); err != nil {
		fmt.Fprintf(os.Stderr, "[SshClient] Failed to write ssh client fail status: %v\n", err)
	}
	os.Exit(1)
}

// main is the ssh-client entry point. It expects exactly nine positional
// arguments:
//
//	hostip hostport login-user login-password public-key private-key
//	pass-phrase socket-file-path tcpTimeout(seconds)
//
// Authentication:
//   - a non-empty login-password is used for password authentication;
//   - otherwise the private key is loaded, decrypted with pass-phrase if
//     one is given.
//
// main then dials hostip:hostport and listens on the Unix socket
// socket-file-path, serving each accepted connection with handleConnection
// on its own goroutine.
//
// Startup success or failure is written to a status file whose path is the
// socket path with its extension replaced by ".status". The process runs
// until a "close" request arrives on closeChan.
func main() {
	var sshClientStatusFilePath string

	// Catch runtime panics in the main goroutine and mark the client as
	// failed. Panics in other goroutines are not recovered here.
	defer func() {
		if r := recover(); r != nil {
			//output stacktrace
			exitWithConnectFail(sshClientStatusFilePath,
				fmt.Sprintf("[SshClient] Runtime panic error occurs in client. error : %v\n%s\n", r, debug.Stack()))
		}
	}()

	if len(os.Args) != 10 {
		fmt.Fprintln(os.Stderr, "[SshClient] Usage: ssh-client <hostip> <hostport> <login-user> <login-password> <public-key> <private-key> <pass-phrase> <socket-file-path> <tcpTimeout>")
		os.Exit(1)
	}

	hostIP = os.Args[1]
	hostPort = os.Args[2]
	loginUser = os.Args[3]
	loginPassword = os.Args[4]
	publicKey = os.Args[5]
	privateKey = os.Args[6]
	passPhrase = os.Args[7]
	socketFilePath = os.Args[8]
	tcpTimeout = os.Args[9]

	sshClientStatusFilePath = strings.TrimSuffix(socketFilePath, filepath.Ext(socketFilePath)) + ".status"

	defer os.Remove(sshClientStatusFilePath)

	if loginPassword == "" {
		signer, err := loadPrivateKey(privateKey, passPhrase)
		if err != nil {
			exitWithConnectFail(sshClientStatusFilePath,
				fmt.Sprintf("[SshClient] Failed to load private key. error : %s\n", err))
		}
		authMethod = ssh.PublicKeys(signer)
	} else {
		authMethod = ssh.Password(loginPassword)
	}

	secs, err := strconv.Atoi(tcpTimeout)
	if err != nil {
		exitWithConnectFail(sshClientStatusFilePath,
			fmt.Sprintf("[SshClient] Invalid tcp timeout: %v\n", err))
	}
	dialTimeout := time.Duration(secs) * time.Second

	// NOTE(review): InsecureIgnoreHostKey accepts any host key, so the
	// connection is open to man-in-the-middle attacks. Consider verifying
	// against a known_hosts file.
	config := &ssh.ClientConfig{
		User:            loginUser,
		Auth:            []ssh.AuthMethod{authMethod},
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         dialTimeout,
	}

	addr := net.JoinHostPort(hostIP, hostPort)

	client, err := ssh.Dial("tcp", addr, config)
	if err != nil {
		exitWithConnectFail(sshClientStatusFilePath,
			fmt.Sprintf("[SshClient] Failed to connect ssh host. error : %s\n", err))
	}
	defer client.Close()

	listener, err := net.Listen("unix", socketFilePath)
	if err != nil {
		exitWithConnectFail(sshClientStatusFilePath,
			fmt.Sprintf("[SshClient] Failed to listen on socket : %s\n", err))
	}

	defer listener.Close()
	defer os.Remove(socketFilePath)

	// acceptDone is closed right before the listener is shut down, so the
	// accept goroutine can tell a deliberate close from an unexpected error.
	acceptDone := make(chan struct{})
	go func() {
		for {
			conn, err := listener.Accept()
			if err != nil {
				select {
				case <-acceptDone:
					return
				default:
				}
				fmt.Fprintf(os.Stderr, "[SshClient] Failed to accept connection: %s\n", err)
				// Back off so a persistent Accept error does not spin the CPU.
				time.Sleep(100 * time.Millisecond)
				continue
			}
			go handleConnection(conn, client)
		}
	}()

	if err := writeSshClientStatus(sshClientStatusFilePath, common.SSHD_CONNECT_SUCCESS); err != nil {
		fmt.Fprintf(os.Stderr, "[SshClient] Failed to write ssh client success status: %v\n", err)
		os.Exit(1)
	}

	<-closeChan
	close(acceptDone)
	listener.Close()
}
