From b50abccae04c9555ee1978ed8035436319ac1b84 Mon Sep 17 00:00:00 2001 From: Khairul Hidayat Date: Sat, 9 Nov 2024 14:37:09 +0000 Subject: [PATCH] feat: update --- frontend/components/ui/alert.tsx | 58 +++++++ frontend/components/ui/button.tsx | 19 +++ frontend/components/ui/os-icons.tsx | 58 +++++++ frontend/lib/api.ts | 6 + .../hosts/components/credentials-section.tsx | 16 ++ frontend/pages/hosts/components/form.tsx | 139 +++------------- .../pages/hosts/components/hosts-list.tsx | 47 ++++-- frontend/pages/hosts/components/incus.tsx | 57 +++++++ frontend/pages/hosts/components/pve.tsx | 46 ++++++ frontend/pages/hosts/components/ssh.tsx | 33 ++++ frontend/pages/hosts/hooks/query.ts | 17 ++ frontend/pages/hosts/types.ts | 6 + frontend/pages/terminal/page.tsx | 2 +- server/app/hosts/repository.go | 24 +-- server/app/hosts/router.go | 22 ++- server/app/hosts/utils.go | 54 ++++++ server/app/keychains/repository.go | 30 +++- server/app/keychains/router.go | 4 +- server/app/ws/term.go | 26 ++- server/app/ws/term_ssh.go | 92 ++--------- server/lib/crypto.go | 4 + server/lib/os.go | 32 ++++ server/lib/ssh.go | 154 ++++++++++++++++++ server/models/host.go | 24 +++ server/models/keychain.go | 1 + server/tests/keychains_test.go | 11 +- 26 files changed, 728 insertions(+), 254 deletions(-) create mode 100644 frontend/components/ui/alert.tsx create mode 100644 frontend/components/ui/button.tsx create mode 100644 frontend/components/ui/os-icons.tsx create mode 100644 frontend/pages/hosts/components/credentials-section.tsx create mode 100644 frontend/pages/hosts/components/incus.tsx create mode 100644 frontend/pages/hosts/components/pve.tsx create mode 100644 frontend/pages/hosts/components/ssh.tsx create mode 100644 frontend/pages/hosts/types.ts create mode 100644 server/app/hosts/utils.go create mode 100644 server/lib/os.go create mode 100644 server/lib/ssh.go diff --git a/frontend/components/ui/alert.tsx b/frontend/components/ui/alert.tsx new file mode 100644 index 0000000..44b75b4 --- /dev/null +++ b/frontend/components/ui/alert.tsx @@ -0,0 +1,58 @@ +import { Card, GetProps, styled, Text, XStack } from "tamagui"; +import Icons from "./icons"; + +const AlertFrame = styled(Card, { + px: "$4", + py: "$3", + bordered: true, + variants: { + variant: { + default: {}, + error: { + backgroundColor: "$red2", + borderColor: "$red5", + }, + }, + } as const, +}); + +const icons: Record = { + error: "alert-circle-outline", +}; + +type AlertProps = GetProps; + +const Alert = ({ children, variant = "default", ...props }: AlertProps) => { + return ( + + + {icons[variant] != null && ( + + )} + + + {children} + + + + ); +}; + +type ErrorAlert = AlertProps & { + error?: unknown | null; +}; + +export const ErrorAlert = ({ error, ...props }: ErrorAlert) => { + if (!error) { + return null; + } + + const message = (error as any)?.message || "Something went wrong"; + return ( + + {message} + + ); +}; + +export default Alert; diff --git a/frontend/components/ui/button.tsx b/frontend/components/ui/button.tsx new file mode 100644 index 0000000..ea9a991 --- /dev/null +++ b/frontend/components/ui/button.tsx @@ -0,0 +1,19 @@ +import React from "react"; +import { GetProps, Button as BaseButton, Spinner } from "tamagui"; + +type ButtonProps = GetProps & { + isDisabled?: boolean; + isLoading?: boolean; +}; + +const Button = ({ icon, isLoading, isDisabled, ...props }: ButtonProps) => { + return ( + : icon} + disabled={isLoading || isDisabled || props.disabled} + {...props} + /> + ); +}; + +export default Button; diff --git a/frontend/components/ui/os-icons.tsx b/frontend/components/ui/os-icons.tsx new file mode 100644 index 0000000..92e348b --- /dev/null +++ b/frontend/components/ui/os-icons.tsx @@ -0,0 +1,58 @@ +import { ComponentPropsWithoutRef } from "react"; +import Icons from "./icons"; + +/* +var osMap = map[string]string{ + "arch": "arch", + "ubuntu": "ubuntu", + "kali": "kali", + "raspbian": "raspbian", + "pop": "pop", + "debian": "debian", + "fedora": "fedora", + "centos": "centos", + "alpine": "alpine", + "mint": "mint", + "suse": "suse", + "darwin": "macos", + "windows": "windows", + "msys": "windows", + "linux": "linux", +} +*/ + +const icons: Record = { + ubuntu: { name: "ubuntu" }, + debian: { name: "debian" }, + arch: { name: "arch" }, + mint: { name: "linux-mint" }, + raspbian: { name: "raspberry-pi" }, + fedora: { name: "fedora" }, + centos: { name: "centos" }, + macos: { name: "apple" }, + windows: { name: "microsoft-windows" }, + linux: { name: "linux" }, +}; + +type OSIconsProps = Omit, "name"> & { + name?: string | null; + fallback?: string; +}; + +const OSIcons = ({ name, fallback, ...props }: OSIconsProps) => { + const icon = icons[name || ""]; + + if (!icon) { + return fallback ? : null; + } + + return ( + + ); +}; + +export default OSIcons; diff --git a/frontend/lib/api.ts b/frontend/lib/api.ts index 21f91ba..3f59476 100644 --- a/frontend/lib/api.ts +++ b/frontend/lib/api.ts @@ -6,6 +6,12 @@ export const BASE_WS_URL = BASE_API_URL.replace("http", "ws"); const api = ofetch.create({ baseURL: BASE_API_URL, + onResponseError: (error) => { + if (error.response._data) { + const message = error.response._data.message; + throw new Error(message || "Something went wrong"); + } + }, }); export const queryClient = new QueryClient(); diff --git a/frontend/pages/hosts/components/credentials-section.tsx b/frontend/pages/hosts/components/credentials-section.tsx new file mode 100644 index 0000000..fbc9d3a --- /dev/null +++ b/frontend/pages/hosts/components/credentials-section.tsx @@ -0,0 +1,16 @@ +import Icons from "@/components/ui/icons"; +import React from "react"; +import { Button, Label, XStack } from "tamagui"; + +export default function CredentialsSection() { + return ( + + + + + ); +} diff --git a/frontend/pages/hosts/components/form.tsx b/frontend/pages/hosts/components/form.tsx index 61dba79..04b5172 100644 --- a/frontend/pages/hosts/components/form.tsx +++ b/frontend/pages/hosts/components/form.tsx @@ -1,22 +1,19 @@ import Icons from "@/components/ui/icons"; import Modal from "@/components/ui/modal"; import { SelectField } from "@/components/ui/select"; -import { useZForm, UseZFormReturn } from "@/hooks/useZForm"; -import api from "@/lib/api"; +import { useZForm } from "@/hooks/useZForm"; import { createDisclosure } from "@/lib/utils"; -import { useQuery } from "@tanstack/react-query"; import React from "react"; -import { Button, Label, ScrollView, XStack } from "tamagui"; -import { - FormSchema, - formSchema, - incusTypes, - pveTypes, - typeOptions, -} from "../schema/form"; +import { ScrollView, XStack } from "tamagui"; +import { FormSchema, formSchema, typeOptions } from "../schema/form"; import { InputField } from "@/components/ui/input"; import FormField from "@/components/ui/form"; -import { useKeychains, useSaveHost } from "../hooks/query"; +import { useSaveHost } from "../hooks/query"; +import { ErrorAlert } from "@/components/ui/alert"; +import Button from "@/components/ui/button"; +import { PVEFormFields } from "./pve"; +import { IncusFormFields } from "./incus"; +import { SSHFormFields } from "./ssh"; export const hostFormModal = createDisclosure(); @@ -26,7 +23,6 @@ const HostForm = () => { const isEditing = data?.id != null; const type = form.watch("type"); - const keys = useKeychains(); const saveMutation = useSaveHost(); const onSubmit = form.handleSubmit((values) => { @@ -44,6 +40,8 @@ const HostForm = () => { title="Host" description={`${isEditing ? "Edit" : "Add new"} host.`} > + + @@ -62,47 +60,17 @@ const HostForm = () => { form={form} name="port" keyboardType="number-pad" - placeholder="SSH Port" + placeholder="Port" /> - {type === "pve" && } - {type === "incus" && } - - - - - - - - ({ - label: key.label, - value: key.id, - }))} - /> - - - {type === "ssh" && ( - - ({ - label: key.label, - value: key.id, - }))} - /> - - )} + {type === "ssh" ? ( + + ) : type === "pve" ? ( + + ) : type === "incus" ? ( + + ) : null} @@ -113,6 +81,7 @@ const HostForm = () => { flex={1} icon={} onPress={onSubmit} + isLoading={saveMutation.isPending} > Save @@ -121,72 +90,4 @@ const HostForm = () => { ); }; -type MiscFormFieldProps = { - form: UseZFormReturn; -}; - -const PVEFormFields = ({ form }: MiscFormFieldProps) => { - return ( - <> - - - - - - - - - - - ); -}; - -const IncusFormFields = ({ form }: MiscFormFieldProps) => { - const type = form.watch("metadata.type"); - - return ( - <> - - - - - - - {type === "lxc" && ( - <> - - - - - - - - )} - - ); -}; - export default HostForm; diff --git a/frontend/pages/hosts/components/hosts-list.tsx b/frontend/pages/hosts/components/hosts-list.tsx index a620547..b299dec 100644 --- a/frontend/pages/hosts/components/hosts-list.tsx +++ b/frontend/pages/hosts/components/hosts-list.tsx @@ -8,8 +8,13 @@ import Icons from "@/components/ui/icons"; import SearchInput from "@/components/ui/search-input"; import { useTermSession } from "@/stores/terminal-sessions"; import { hostFormModal } from "./form"; +import OSIcons from "@/components/ui/os-icons"; -const HostsList = () => { +type HostsListProps = { + allowEdit?: boolean; +}; + +const HostsList = ({ allowEdit = true }: HostsListProps) => { const openSession = useTermSession((i) => i.push); const navigation = useNavigation(); const [search, setSearch] = useState(""); @@ -37,6 +42,7 @@ const HostsList = () => { }, [hosts.data, search]); const onOpen = (host: any) => { + if (!allowEdit) return; hostFormModal.onOpen(host); }; @@ -88,9 +94,9 @@ const HostsList = () => { flexBasis="100%" cursor="pointer" $gtXs={{ flexBasis: "50%" }} - $gtSm={{ flexBasis: "33.3%" }} - $gtMd={{ flexBasis: "25%" }} - $gtLg={{ flexBasis: "20%" }} + $gtMd={{ flexBasis: "33.3%" }} + $gtLg={{ flexBasis: "25%" }} + $gtXl={{ flexBasis: "20%" }} p="$2" group numberOfTaps={2} @@ -99,6 +105,13 @@ const HostsList = () => { > + + {host.label} @@ -106,18 +119,20 @@ const HostsList = () => { - + {allowEdit && ( + + )} diff --git a/frontend/pages/hosts/components/incus.tsx b/frontend/pages/hosts/components/incus.tsx new file mode 100644 index 0000000..a15a26e --- /dev/null +++ b/frontend/pages/hosts/components/incus.tsx @@ -0,0 +1,57 @@ +import FormField from "@/components/ui/form"; +import { MiscFormFieldProps } from "../types"; +import { InputField } from "@/components/ui/input"; +import { SelectField } from "@/components/ui/select"; +import { incusTypes } from "../schema/form"; +import CredentialsSection from "./credentials-section"; +import { useKeychainsOptions } from "../hooks/query"; + +export const IncusFormFields = ({ form }: MiscFormFieldProps) => { + const keys = useKeychainsOptions(); + const type = form.watch("metadata.type"); + + return ( + <> + + + + + + + {type === "lxc" && ( + <> + + + + + + + + )} + + + + + i.type === "cert")} + /> + + + ); +}; diff --git a/frontend/pages/hosts/components/pve.tsx b/frontend/pages/hosts/components/pve.tsx new file mode 100644 index 0000000..6cb5af8 --- /dev/null +++ b/frontend/pages/hosts/components/pve.tsx @@ -0,0 +1,46 @@ +import FormField from "@/components/ui/form"; +import { MiscFormFieldProps } from "../types"; +import { InputField } from "@/components/ui/input"; +import { SelectField } from "@/components/ui/select"; +import { pveTypes } from "../schema/form"; +import { useKeychainsOptions } from "../hooks/query"; +import CredentialsSection from "./credentials-section"; + +export const PVEFormFields = ({ form }: MiscFormFieldProps) => { + const keys = useKeychainsOptions(); + + return ( + <> + + + + + + + + + + + + + + i.type === "pve")} + /> + + + ); +}; diff --git a/frontend/pages/hosts/components/ssh.tsx b/frontend/pages/hosts/components/ssh.tsx new file mode 100644 index 0000000..293e685 --- /dev/null +++ b/frontend/pages/hosts/components/ssh.tsx @@ -0,0 +1,33 @@ +import FormField from "@/components/ui/form"; +import { MiscFormFieldProps } from "../types"; +import { SelectField } from "@/components/ui/select"; +import CredentialsSection from "./credentials-section"; +import { useKeychainsOptions } from "../hooks/query"; + +export const SSHFormFields = ({ form }: MiscFormFieldProps) => { + const keys = useKeychainsOptions(); + + return ( + <> + + + + i.type === "user")} + /> + + + + i.type === "rsa")} + /> + + + ); +}; diff --git a/frontend/pages/hosts/hooks/query.ts b/frontend/pages/hosts/hooks/query.ts index d7cfc3c..1961d6e 100644 --- a/frontend/pages/hosts/hooks/query.ts +++ b/frontend/pages/hosts/hooks/query.ts @@ -1,6 +1,7 @@ import { useMutation, useQuery } from "@tanstack/react-query"; import { FormSchema } from "../schema/form"; import api, { queryClient } from "@/lib/api"; +import { useMemo } from "react"; export const useKeychains = () => { return useQuery({ @@ -10,6 +11,22 @@ export const useKeychains = () => { }); }; +export const useKeychainsOptions = () => { + const keys = useKeychains(); + + const data = useMemo(() => { + const items: any[] = keys.data || []; + + return items.map((key: any) => ({ + type: key.type, + label: key.label, + value: key.id, + })); + }, [keys.data]); + + return data; +}; + export const useSaveHost = () => { return useMutation({ mutationFn: async (body: FormSchema) => { diff --git a/frontend/pages/hosts/types.ts b/frontend/pages/hosts/types.ts new file mode 100644 index 0000000..e363d95 --- /dev/null +++ b/frontend/pages/hosts/types.ts @@ -0,0 +1,6 @@ +import { UseZFormReturn } from "@/hooks/useZForm"; +import { FormSchema } from "./schema/form"; + +export type MiscFormFieldProps = { + form: UseZFormReturn; +}; diff --git a/frontend/pages/terminal/page.tsx b/frontend/pages/terminal/page.tsx index a7877e4..178b45d 100644 --- a/frontend/pages/terminal/page.tsx +++ b/frontend/pages/terminal/page.tsx @@ -35,7 +35,7 @@ const TerminalPage = () => { style={{ flex: 1 }} page={curSession} onChangePage={setSession} - EmptyComponent={HostsList} + EmptyComponent={() => } > {sessions.map((session) => ( diff --git a/server/app/hosts/repository.go b/server/app/hosts/repository.go index 4e5811d..d7c2aa1 100644 --- a/server/app/hosts/repository.go +++ b/server/app/hosts/repository.go @@ -8,7 +8,7 @@ import ( type Hosts struct{ db *gorm.DB } -func NewHostsRepository() *Hosts { +func NewRepository() *Hosts { return &Hosts{db: db.Get()} } @@ -19,13 +19,7 @@ func (r *Hosts) GetAll() ([]*models.Host, error) { return rows, ret.Error } -type GetHostResult struct { - Host *models.Host - Key map[string]interface{} - AltKey map[string]interface{} -} - -func (r *Hosts) Get(id string) (*GetHostResult, error) { +func (r *Hosts) Get(id string) (*models.HostDecrypted, error) { var host models.Host ret := r.db.Joins("Key").Joins("AltKey").Where("hosts.id = ?", id).First(&host) @@ -33,17 +27,9 @@ func (r *Hosts) Get(id string) (*GetHostResult, error) { return nil, ret.Error } - res := &GetHostResult{Host: &host} - - if host.Key.Data != "" { - if err := host.Key.DecryptData(&res.Key); err != nil { - return nil, err - } - } - if host.AltKey.Data != "" { - if err := host.AltKey.DecryptData(&res.AltKey); err != nil { - return nil, err - } + res, err := host.DecryptKeys() + if err != nil { + return nil, err } return res, ret.Error diff --git a/server/app/hosts/router.go b/server/app/hosts/router.go index d185876..945afbf 100644 --- a/server/app/hosts/router.go +++ b/server/app/hosts/router.go @@ -19,7 +19,7 @@ func Router(app *fiber.App) { } func getAll(c *fiber.Ctx) error { - repo := NewHostsRepository() + repo := NewRepository() rows, err := repo.GetAll() if err != nil { return utils.ResponseError(c, err, 500) @@ -36,7 +36,7 @@ func create(c *fiber.Ctx) error { return utils.ResponseError(c, err, 500) } - repo := NewHostsRepository() + repo := NewRepository() item := &models.Host{ Type: body.Type, Label: body.Label, @@ -47,6 +47,13 @@ func create(c *fiber.Ctx) error { KeyID: body.KeyID, AltKeyID: body.AltKeyID, } + + osName, err := tryConnect(item) + if err != nil { + return utils.ResponseError(c, fmt.Errorf("cannot connect to the host: %s", err), 500) + } + item.OS = osName + if err := repo.Create(item); err != nil { return utils.ResponseError(c, err, 500) } @@ -60,7 +67,7 @@ func update(c *fiber.Ctx) error { return utils.ResponseError(c, err, 500) } - repo := NewHostsRepository() + repo := NewRepository() id := c.Params("id") exist, _ := repo.Exists(id) @@ -79,6 +86,13 @@ func update(c *fiber.Ctx) error { KeyID: body.KeyID, AltKeyID: body.AltKeyID, } + + osName, err := tryConnect(item) + if err != nil { + return utils.ResponseError(c, fmt.Errorf("cannot connect to the host: %s", err), 500) + } + item.OS = osName + if err := repo.Update(item); err != nil { return utils.ResponseError(c, err, 500) } @@ -87,7 +101,7 @@ func update(c *fiber.Ctx) error { } func delete(c *fiber.Ctx) error { - repo := NewHostsRepository() + repo := NewRepository() id := c.Params("id") exist, _ := repo.Exists(id) diff --git a/server/app/hosts/utils.go b/server/app/hosts/utils.go new file mode 100644 index 0000000..837edab --- /dev/null +++ b/server/app/hosts/utils.go @@ -0,0 +1,54 @@ +package hosts + +import ( + "fmt" + + "rul.sh/vaulterm/app/keychains" + "rul.sh/vaulterm/lib" + "rul.sh/vaulterm/models" +) + +func tryConnect(host *models.Host) (string, error) { + keyRepo := keychains.NewRepository() + + var key map[string]interface{} + var altKey map[string]interface{} + + if host.KeyID != nil { + keychain, _ := keyRepo.Get(*host.KeyID) + if keychain == nil { + return "", fmt.Errorf("key %s not found", *host.KeyID) + } + keychain.DecryptData(&key) + } + if host.AltKeyID != nil { + keychain, _ := keyRepo.Get(*host.AltKeyID) + if keychain == nil { + return "", fmt.Errorf("key %s not found", *host.KeyID) + } + keychain.DecryptData(&altKey) + } + + if host.Type == "ssh" { + c := lib.NewSSHClient(&lib.SSHClientConfig{ + HostName: host.Host, + Port: host.Port, + Key: key, + AltKey: altKey, + }) + + con, err := c.Connect() + if err != nil { + return "", err + } + + os, err := c.GetOS(c, con) + if err != nil { + return "", err + } + + return os, nil + } + + return "", nil +} diff --git a/server/app/keychains/repository.go b/server/app/keychains/repository.go index 07a2020..1025fc6 100644 --- a/server/app/keychains/repository.go +++ b/server/app/keychains/repository.go @@ -8,7 +8,7 @@ import ( type Keychains struct{ db *gorm.DB } -func NewKeychainsRepository() *Keychains { +func NewRepository() *Keychains { return &Keychains{db: db.Get()} } @@ -22,3 +22,31 @@ func (r *Keychains) GetAll() ([]*models.Keychain, error) { func (r *Keychains) Create(item *models.Keychain) error { return r.db.Create(item).Error } + +func (r *Keychains) Get(id string) (*models.Keychain, error) { + var keychain models.Keychain + if err := r.db.Where("id = ?", id).First(&keychain).Error; err != nil { + return nil, err + } + + return &keychain, nil +} + +type KeychainDecrypted struct { + models.Keychain + Data map[string]interface{} +} + +func (r *Keychains) GetDecrypted(id string) (*KeychainDecrypted, error) { + keychain, err := r.Get(id) + if err != nil { + return nil, err + } + + var data map[string]interface{} + if err := keychain.DecryptData(&data); err != nil { + return nil, err + } + + return &KeychainDecrypted{Keychain: *keychain, Data: data}, nil +} diff --git a/server/app/keychains/router.go b/server/app/keychains/router.go index c02266e..ca585b9 100644 --- a/server/app/keychains/router.go +++ b/server/app/keychains/router.go @@ -16,7 +16,7 @@ func Router(app *fiber.App) { } func getAll(c *fiber.Ctx) error { - repo := NewKeychainsRepository() + repo := NewRepository() rows, err := repo.GetAll() if err != nil { return utils.ResponseError(c, err, 500) @@ -33,7 +33,7 @@ func create(c *fiber.Ctx) error { return utils.ResponseError(c, err, 500) } - repo := NewKeychainsRepository() + repo := NewRepository() item := &models.Keychain{ Type: body.Type, diff --git a/server/app/ws/term.go b/server/app/ws/term.go index f09f04a..7899a1d 100644 --- a/server/app/ws/term.go +++ b/server/app/ws/term.go @@ -6,13 +6,14 @@ import ( "github.com/gofiber/contrib/websocket" "rul.sh/vaulterm/app/hosts" "rul.sh/vaulterm/lib" + "rul.sh/vaulterm/models" "rul.sh/vaulterm/utils" ) func HandleTerm(c *websocket.Conn) { hostId := c.Query("hostId") - hostRepo := hosts.NewHostsRepository() + hostRepo := hosts.NewRepository() data, err := hostRepo.Get(hostId) if data == nil { @@ -33,30 +34,27 @@ func HandleTerm(c *websocket.Conn) { } } -func sshHandler(c *websocket.Conn, data *hosts.GetHostResult) { - username, _ := data.Key["username"].(string) - password, _ := data.Key["password"].(string) - - cfg := &SSHConfig{ +func sshHandler(c *websocket.Conn, data *models.HostDecrypted) { + cfg := lib.NewSSHClient(&lib.SSHClientConfig{ HostName: data.Host.Host, - Port: data.Host.Port, - User: username, - Password: password, - } + Port: data.Port, + Key: data.Key, + AltKey: data.AltKey, + }) if err := NewSSHWebsocketSession(c, cfg); err != nil { c.WriteMessage(websocket.TextMessage, []byte(err.Error())) } } -func pveHandler(c *websocket.Conn, data *hosts.GetHostResult) { +func pveHandler(c *websocket.Conn, data *models.HostDecrypted) { client := c.Query("client") username, _ := data.Key["username"].(string) password, _ := data.Key["password"].(string) pve := &lib.PVEServer{ HostName: data.Host.Host, - Port: data.Host.Port, + Port: data.Port, Username: username, Password: password, } @@ -84,7 +82,7 @@ func pveHandler(c *websocket.Conn, data *hosts.GetHostResult) { } } -func incusHandler(c *websocket.Conn, data *hosts.GetHostResult) { +func incusHandler(c *websocket.Conn, data *models.HostDecrypted) { shell := c.Query("shell") cert, _ := data.Key["cert"].(string) @@ -97,7 +95,7 @@ func incusHandler(c *websocket.Conn, data *hosts.GetHostResult) { incus := &lib.IncusServer{ HostName: data.Host.Host, - Port: data.Host.Port, + Port: data.Port, ClientCert: cert, ClientKey: key, } diff --git a/server/app/ws/term_ssh.go b/server/app/ws/term_ssh.go index 9827d76..82dfadd 100644 --- a/server/app/ws/term_ssh.go +++ b/server/app/ws/term_ssh.go @@ -1,101 +1,37 @@ package ws import ( - "fmt" "io" "log" "strconv" "strings" "github.com/gofiber/contrib/websocket" - "golang.org/x/crypto/ssh" + "rul.sh/vaulterm/lib" ) -type SSHConfig struct { - HostName string - User string - Password string - Port int - PrivateKey string - PrivateKeyPassphrase string -} - -func NewSSHWebsocketSession(c *websocket.Conn, cfg *SSHConfig) error { - // Set up SSH client configuration - port := cfg.Port - if port == 0 { - port = 22 - } - auth := []ssh.AuthMethod{ - ssh.Password(cfg.Password), - } - - if cfg.PrivateKey != "" { - var err error - var signer ssh.Signer - - if cfg.PrivateKeyPassphrase != "" { - signer, err = ssh.ParsePrivateKeyWithPassphrase([]byte(cfg.PrivateKey), []byte(cfg.PrivateKeyPassphrase)) - } else { - signer, err = ssh.ParsePrivateKey([]byte(cfg.PrivateKey)) - } - - if err != nil { - return fmt.Errorf("unable to parse private key: %v", err) - } - auth = append(auth, ssh.PublicKeys(signer)) - } - - sshConfig := &ssh.ClientConfig{ - User: cfg.User, - Auth: auth, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - } - - // Connect to SSH server - hostName := fmt.Sprintf("%s:%d", cfg.HostName, port) - sshConn, err := ssh.Dial("tcp", hostName, sshConfig) +func NewSSHWebsocketSession(c *websocket.Conn, client *lib.SSHClient) error { + con, err := client.Connect() if err != nil { + log.Printf("error connecting to SSH: %v", err) return err } - defer sshConn.Close() + defer con.Close() - // Start an SSH shell session - session, err := sshConn.NewSession() + shell, err := client.StartPtyShell(con) if err != nil { + log.Printf("error starting SSH shell: %v", err) return err } + + session := shell.Session defer session.Close() - stdoutPipe, err := session.StdoutPipe() - if err != nil { - return err - } - - stderrPipe, err := session.StderrPipe() - if err != nil { - return err - } - - stdinPipe, err := session.StdinPipe() - if err != nil { - return err - } - - err = session.RequestPty("xterm-256color", 80, 24, ssh.TerminalModes{}) - if err != nil { - return err - } - - if err := session.Shell(); err != nil { - return err - } - // Goroutine to send SSH stdout to WebSocket go func() { buf := make([]byte, 1024) for { - n, err := stdoutPipe.Read(buf) + n, err := shell.Stdout.Read(buf) if err != nil { if err != io.EOF { log.Printf("error reading from SSH stdout: %v", err) @@ -114,7 +50,7 @@ func NewSSHWebsocketSession(c *websocket.Conn, cfg *SSHConfig) error { go func() { buf := make([]byte, 1024) for { - n, err := stderrPipe.Read(buf) + n, err := shell.Stderr.Read(buf) if err != nil { if err != io.EOF { log.Printf("error reading from SSH stderr: %v", err) @@ -135,6 +71,7 @@ func NewSSHWebsocketSession(c *websocket.Conn, cfg *SSHConfig) error { for { _, msg, err := c.ReadMessage() if err != nil { + log.Printf("error reading from websocket: %v", err) break } @@ -148,8 +85,10 @@ func NewSSHWebsocketSession(c *websocket.Conn, cfg *SSHConfig) error { continue } - stdinPipe.Write(msg) + shell.Stdin.Write(msg) } + + log.Println("SSH session closed") }() // Wait for the SSH session to close @@ -158,6 +97,5 @@ func NewSSHWebsocketSession(c *websocket.Conn, cfg *SSHConfig) error { return err } - log.Println("SSH session ended normally") return nil } diff --git a/server/lib/crypto.go b/server/lib/crypto.go index 0453845..39ecae3 100644 --- a/server/lib/crypto.go +++ b/server/lib/crypto.go @@ -114,6 +114,10 @@ func Decrypt(encrypted string) (string, error) { return "", err } + if len(data) < 16 { + return "", fmt.Errorf("invalid encrypted data") + } + block, err := aes.NewCipher(keyDec) if err != nil { return "", err diff --git a/server/lib/os.go b/server/lib/os.go new file mode 100644 index 0000000..0d25367 --- /dev/null +++ b/server/lib/os.go @@ -0,0 +1,32 @@ +package lib + +import "strings" + +// Map of OS identifiers and their corresponding names +var osMap = map[string]string{ + "arch": "arch", + "ubuntu": "ubuntu", + "kali": "kali", + "raspbian": "raspbian", + "pop": "pop", + "debian": "debian", + "fedora": "fedora", + "centos": "centos", + "alpine": "alpine", + "mint": "mint", + "suse": "suse", + "darwin": "macos", + "windows": "windows", + "msys": "windows", + "linux": "linux", +} + +func DetectOS(str string) string { + str = strings.ToLower(str) + for keyword, osName := range osMap { + if strings.Contains(str, keyword) { + return osName + } + } + return "" +} diff --git a/server/lib/ssh.go b/server/lib/ssh.go new file mode 100644 index 0000000..1fcb39e --- /dev/null +++ b/server/lib/ssh.go @@ -0,0 +1,154 @@ +package lib + +import ( + "fmt" + "io" + + "golang.org/x/crypto/ssh" +) + +type SSHClient struct { + HostName string + User string + Password string + Port int + PrivateKey string + PrivateKeyPassphrase string +} + +type SSHClientConfig struct { + HostName string + Port int + Key map[string]interface{} + AltKey map[string]interface{} +} + +func NewSSHClient(cfg *SSHClientConfig) *SSHClient { + username, _ := cfg.Key["username"].(string) + password, _ := cfg.Key["password"].(string) + privateKey, _ := cfg.AltKey["private"].(string) + passphrase, _ := cfg.AltKey["passphrase"].(string) + + return &SSHClient{ + HostName: cfg.HostName, + User: username, + Password: password, + Port: cfg.Port, + PrivateKey: privateKey, + PrivateKeyPassphrase: passphrase, + } +} + +func (s *SSHClient) Connect() (*ssh.Client, error) { + // Set up SSH client configuration + port := s.Port + if port == 0 { + port = 22 + } + auth := []ssh.AuthMethod{ + ssh.Password(s.Password), + } + + if s.PrivateKey != "" { + var err error + var signer ssh.Signer + + if s.PrivateKeyPassphrase != "" { + signer, err = ssh.ParsePrivateKeyWithPassphrase([]byte(s.PrivateKey), []byte(s.PrivateKeyPassphrase)) + } else { + signer, err = ssh.ParsePrivateKey([]byte(s.PrivateKey)) + } + + if err != nil { + return nil, fmt.Errorf("unable to parse private key: %v", err) + } + auth = append(auth, ssh.PublicKeys(signer)) + } + + sshConfig := &ssh.ClientConfig{ + User: s.User, + Auth: auth, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + + // Connect to SSH server + hostName := fmt.Sprintf("%s:%d", s.HostName, port) + sshConn, err := ssh.Dial("tcp", hostName, sshConfig) + if err != nil { + return nil, err + } + + return sshConn, nil +} + +type PtyShellRes struct { + Stdout io.Reader + Stderr io.Reader + Stdin io.WriteCloser + Session *ssh.Session +} + +func (s *SSHClient) StartPtyShell(sshConn *ssh.Client) (res *PtyShellRes, err error) { + // Start an SSH shell session + session, err := sshConn.NewSession() + if err != nil { + return nil, err + } + + stdoutPipe, err := session.StdoutPipe() + if err != nil { + return nil, err + } + + stderrPipe, err := session.StderrPipe() + if err != nil { + return nil, err + } + + stdinPipe, err := session.StdinPipe() + if err != nil { + return nil, err + } + + err = session.RequestPty("xterm-256color", 80, 24, ssh.TerminalModes{}) + if err != nil { + return nil, err + } + + if err := session.Shell(); err != nil { + return nil, err + } + + return &PtyShellRes{ + Stdout: stdoutPipe, + Stderr: stderrPipe, + Stdin: stdinPipe, + Session: session, + }, nil +} + +func (s *SSHClient) Exec(sshConn *ssh.Client, command string) (string, error) { + // Start an SSH shell session + session, err := sshConn.NewSession() + if err != nil { + return "", err + } + defer session.Close() + + // Execute the command + output, err := session.CombinedOutput(command) + if err != nil { + return "", err + } + + return string(output), nil +} + +func (s *SSHClient) GetOS(client *SSHClient, con *ssh.Client) (string, error) { + out, err := client.Exec(con, "cat /etc/os-release || uname -a || systeminfo") + if err != nil { + return "", err + } + + return DetectOS(out), nil +} diff --git a/server/models/host.go b/server/models/host.go index af8e2f0..404223b 100644 --- a/server/models/host.go +++ b/server/models/host.go @@ -18,6 +18,7 @@ type Host struct { Label string `json:"label"` Host string `json:"host" gorm:"type:varchar(64)"` Port int `json:"port" gorm:"type:smallint"` + OS string `json:"os" gorm:"type:varchar(32)"` Metadata datatypes.JSONMap `json:"metadata"` ParentID *string `json:"parentId" gorm:"index:hosts_parent_id_idx;type:varchar(26)"` @@ -30,3 +31,26 @@ type Host struct { Timestamps SoftDeletes } + +type HostDecrypted struct { + Host + Key map[string]interface{} + AltKey map[string]interface{} +} + +func (h *Host) DecryptKeys() (*HostDecrypted, error) { + res := &HostDecrypted{Host: *h} + + if h.Key.Data != "" { + if err := h.Key.DecryptData(&res.Key); err != nil { + return nil, err + } + } + if h.AltKey.Data != "" { + if err := h.AltKey.DecryptData(&res.AltKey); err != nil { + return nil, err + } + } + + return res, nil +} diff --git a/server/models/keychain.go b/server/models/keychain.go index 47262f1..b002e0d 100644 --- a/server/models/keychain.go +++ b/server/models/keychain.go @@ -8,6 +8,7 @@ import ( const ( KeychainTypeUserPass = "user" + KeychainTypePVE = "pve" KeychainTypeRSA = "rsa" KeychainTypeCertificate = "cert" ) diff --git a/server/tests/keychains_test.go b/server/tests/keychains_test.go index 0e96d6d..3273596 100644 --- a/server/tests/keychains_test.go +++ b/server/tests/keychains_test.go @@ -30,7 +30,16 @@ func TestKeychainsCreate(t *testing.T) { } // data := map[string]interface{}{ - // "type": "user", + // "type": "rsa", + // "label": "RSA Key", + // "data": map[string]interface{}{ + // "private": "", + // "passphrase": "", + // }, + // } + + // data := map[string]interface{}{ + // "type": "pve", // "label": "PVE Key", // "data": map[string]interface{}{ // "username": "root@pam",