diff --git a/frontend/app/_providers.tsx b/frontend/app/_providers.tsx index 15dc0d3..56da1e1 100644 --- a/frontend/app/_providers.tsx +++ b/frontend/app/_providers.tsx @@ -12,6 +12,7 @@ import { router, usePathname, useRootNavigationState } from "expo-router"; import { useAuthStore } from "@/stores/auth"; import { PortalProvider } from "tamagui"; import { queryClient } from "@/lib/api"; +import { useAppStore } from "@/stores/app"; type Props = PropsWithChildren; @@ -53,16 +54,24 @@ const AuthProvider = () => { const pathname = usePathname(); const rootNavigationState = useRootNavigationState(); const { isLoggedIn } = useAuthStore(); + const { curServer } = useAppStore(); useEffect(() => { if (!rootNavigationState?.key) { return; } - if (!pathname.startsWith("/auth") && !isLoggedIn) { + if (!curServer && !pathname.startsWith("/server")) { + router.replace("/server"); + return; + } + + const isProtected = !["/auth", "/server"].find((path) => + pathname.startsWith(path) + ); + + if (isProtected && !isLoggedIn) { router.replace("/auth/login"); - } else if (pathname.startsWith("/auth") && isLoggedIn) { - router.replace("/"); } }, [pathname, rootNavigationState, isLoggedIn]); diff --git a/frontend/app/auth/login.tsx b/frontend/app/auth/login.tsx index 0a0da13..7c556ca 100644 --- a/frontend/app/auth/login.tsx +++ b/frontend/app/auth/login.tsx @@ -1,14 +1,3 @@ -import { View, Text, Button } from "tamagui"; -import React from "react"; -import authStore from "@/stores/auth"; +import LoginPage from "@/pages/auth/login"; -export default function LoginPage() { - return ( - - LoginPage - - - ); -} +export default LoginPage; diff --git a/frontend/app/index.tsx b/frontend/app/index.tsx index fac85db..0b7d8d9 100644 --- a/frontend/app/index.tsx +++ b/frontend/app/index.tsx @@ -1,9 +1,15 @@ import React from "react"; import { Redirect } from "expo-router"; import { useTermSession } from "@/stores/terminal-sessions"; +import { useAppStore } from "@/stores/app"; export default function index() { const { sessions, curSession } = useTermSession(); + const { servers, curServer } = useAppStore(); + + if (!servers.length || !curServer) { + return ; + } return ( { - const query = new URLSearchParams(params); + const { token } = useAuthStore(); + const query = new URLSearchParams({ ...params, sid: token || "" }); const url = `${BASE_WS_URL}/ws/term?${query}`; switch (type) { diff --git a/frontend/components/containers/theme-switcher.tsx b/frontend/components/containers/theme-switcher.tsx new file mode 100644 index 0000000..fbd5871 --- /dev/null +++ b/frontend/components/containers/theme-switcher.tsx @@ -0,0 +1,29 @@ +import React from "react"; +import { Button, GetProps } from "tamagui"; +import Icons from "../ui/icons"; +import useThemeStore from "@/stores/theme"; + +type Props = GetProps & { + iconSize?: number; +}; + +const ThemeSwitcher = ({ iconSize = 24, ...props }: Props) => { + const { theme, toggle } = useThemeStore(); + + return ( + + + + + + + ); +} diff --git a/frontend/pages/auth/schema.ts b/frontend/pages/auth/schema.ts new file mode 100644 index 0000000..9e5daf9 --- /dev/null +++ b/frontend/pages/auth/schema.ts @@ -0,0 +1,10 @@ +import { z } from "zod"; + +export const loginSchema = z.object({ + username: z.string(), + password: z.string(), +}); + +export const loginResultSchema = z.object({ + sessionId: z.string().min(40), +}); diff --git a/frontend/pages/server/page.tsx b/frontend/pages/server/page.tsx new file mode 100644 index 0000000..0baa84b --- /dev/null +++ b/frontend/pages/server/page.tsx @@ -0,0 +1,78 @@ +import { View, Text, ScrollView, Card } from "tamagui"; +import React from "react"; +import FormField from "@/components/ui/form"; +import { InputField } from "@/components/ui/input"; +import { useZForm } from "@/hooks/useZForm"; +import { getServerResultSchema, serverSchema } from "./schema"; +import { router, Stack } from "expo-router"; +import Button from "@/components/ui/button"; +import ThemeSwitcher from "@/components/containers/theme-switcher"; +import { useMutation } from "@tanstack/react-query"; +import { ofetch } from "ofetch"; +import { z } from "zod"; +import { ErrorAlert } from "@/components/ui/alert"; +import { addServer } from "@/stores/app"; + +export default function ServerPage() { + const form = useZForm(serverSchema); + + const serverConnect = useMutation({ + mutationFn: async (body: z.infer) => { + const res = await ofetch(body.url + "/server"); + const { data } = getServerResultSchema.safeParse(res); + if (!data) { + throw new Error("Invalid server"); + } + return data; + }, + onSuccess(data, payload) { + addServer({ url: payload.url, name: data.name }, true); + router.replace("/auth/login"); + }, + }); + + const onSubmit = form.handleSubmit((values) => { + serverConnect.mutate(values); + }); + + return ( + <> + ( + + ), + }} + /> + + + + Connect to Server + + + + + + + + + + + + ); +} diff --git a/frontend/pages/server/schema.ts b/frontend/pages/server/schema.ts new file mode 100644 index 0000000..b8adceb --- /dev/null +++ b/frontend/pages/server/schema.ts @@ -0,0 +1,10 @@ +import { z } from "zod"; + +export const serverSchema = z.object({ + url: z.string().url("Invalid URL"), +}); + +export const getServerResultSchema = z.object({ + name: z.string(), + version: z.string().min(1), +}); diff --git a/frontend/stores/app.ts b/frontend/stores/app.ts new file mode 100644 index 0000000..ad0fb09 --- /dev/null +++ b/frontend/stores/app.ts @@ -0,0 +1,62 @@ +import { createStore, useStore } from "zustand"; +import { persist, createJSONStorage } from "zustand/middleware"; +import AsyncStorage from "@react-native-async-storage/async-storage"; + +type AppServer = { + name?: string; + url: string; +}; + +type AppStore = { + servers: AppServer[]; + curServerIdx?: number | null; +}; + +const appStore = createStore( + persist( + () => ({ + servers: [], + curServerIdx: null, + }), + { + name: "vaulterm:app", + storage: createJSONStorage(() => AsyncStorage), + } + ) +); + +export function addServer(srv: AppServer, setActive?: boolean) { + const curServers = appStore.getState().servers; + const isExist = curServers.findIndex((s) => s.url === srv.url); + + if (isExist >= 0) { + setActiveServer(isExist); + return; + } + + appStore.setState((state) => ({ + servers: [...state.servers, srv], + curServerIdx: setActive ? state.servers.length : state.curServerIdx, + })); +} + +export function removeServer(idx: number) { + appStore.setState((state) => ({ + servers: state.servers.filter((_, i) => i !== idx), + curServerIdx: state.curServerIdx === idx ? null : state.curServerIdx, + })); +} + +export function setActiveServer(idx: number) { + appStore.setState({ curServerIdx: idx }); +} + +export const useAppStore = () => { + const state = useStore(appStore); + const curServer = + state.curServerIdx != null ? state.servers[state.curServerIdx] : null; + + return { ...state, curServer }; +}; + +export default appStore; diff --git a/frontend/stores/auth.ts b/frontend/stores/auth.ts index d4facf4..9d812e5 100644 --- a/frontend/stores/auth.ts +++ b/frontend/stores/auth.ts @@ -12,7 +12,7 @@ const authStore = createStore( token: null, }), { - name: "auth", + name: "vaulterm:auth", storage: createJSONStorage(() => AsyncStorage), } ) diff --git a/frontend/stores/terminal-sessions.ts b/frontend/stores/terminal-sessions.ts index 992bb2f..55e47be 100644 --- a/frontend/stores/terminal-sessions.ts +++ b/frontend/stores/terminal-sessions.ts @@ -38,6 +38,9 @@ export const useTermSession = create( set({ curSession: idx }); }, }), - { name: "term-sessions", storage: createJSONStorage(() => AsyncStorage) } + { + name: "vaulterm:term-sessions", + storage: createJSONStorage(() => AsyncStorage), + } ) ); diff --git a/frontend/stores/theme.ts b/frontend/stores/theme.ts index 105fef7..9ae291f 100644 --- a/frontend/stores/theme.ts +++ b/frontend/stores/theme.ts @@ -20,7 +20,7 @@ const useThemeStore = create( }, }), { - name: "theme", + name: "vaulterm:theme", storage: createJSONStorage(() => AsyncStorage), } ) diff --git a/server/app/app.go b/server/app/app.go index 498e194..3924f2d 100644 --- a/server/app/app.go +++ b/server/app/app.go @@ -5,10 +5,8 @@ import ( "github.com/gofiber/fiber/v2/middleware/cors" "github.com/joho/godotenv" "rul.sh/vaulterm/app/auth" - "rul.sh/vaulterm/app/hosts" - "rul.sh/vaulterm/app/keychains" - "rul.sh/vaulterm/app/ws" "rul.sh/vaulterm/db" + "rul.sh/vaulterm/middleware" ) func NewApp() *fiber.App { @@ -18,20 +16,26 @@ func NewApp() *fiber.App { // Create fiber app app := fiber.New(fiber.Config{ErrorHandler: ErrorHandler}) - - // Middlewares app.Use(cors.New()) - // Init app routes - auth.Router(app) - hosts.Router(app) - keychains.Router(app) - ws.Router(app) + // Server info + app.Get("/server", func(c *fiber.Ctx) error { + return c.JSON(&fiber.Map{ + "name": "Vaulterm", + "version": "0.0.1", + }) + }) // Health check app.Get("/health-check", func(c *fiber.Ctx) error { return c.SendString("OK") }) + app.Use(middleware.Auth) + auth.Router(app) + + app.Use(middleware.Protected()) + InitRouter(app) + return app } diff --git a/server/app/auth/repository.go b/server/app/auth/repository.go index 4c3c119..df3cf52 100644 --- a/server/app/auth/repository.go +++ b/server/app/auth/repository.go @@ -9,7 +9,7 @@ import ( type Auth struct{ db *gorm.DB } -func NewAuthRepository() *Auth { +func NewRepository() *Auth { return &Auth{db: db.Get()} } diff --git a/server/app/auth/router.go b/server/app/auth/router.go index 823d4f9..fbb6352 100644 --- a/server/app/auth/router.go +++ b/server/app/auth/router.go @@ -1,10 +1,9 @@ package auth import ( - "strings" - "github.com/gofiber/fiber/v2" "rul.sh/vaulterm/lib" + "rul.sh/vaulterm/middleware" "rul.sh/vaulterm/utils" ) @@ -12,12 +11,12 @@ func Router(app *fiber.App) { router := app.Group("/auth") router.Post("/login", login) - router.Get("/user", getUser) - router.Post("/logout", logout) + router.Get("/user", middleware.Protected(), getUser) + router.Post("/logout", middleware.Protected(), logout) } func login(c *fiber.Ctx) error { - repo := NewAuthRepository() + repo := NewRepository() var body LoginSchema if err := c.BodyParser(&body); err != nil { @@ -54,32 +53,15 @@ func login(c *fiber.Ctx) error { } func getUser(c *fiber.Ctx) error { - auth := c.Get("Authorization") - var sessionId string - - if auth != "" { - sessionId = strings.Split(auth, " ")[1] - } - - repo := NewAuthRepository() - session, err := repo.GetSession(sessionId) - if err != nil { - return utils.ResponseError(c, err, 500) - } - - return c.JSON(session) + user := utils.GetUser(c) + return c.JSON(user) } func logout(c *fiber.Ctx) error { - auth := c.Get("Authorization") force := c.Query("force") - var sessionId string + sessionId := c.Locals("sessionId").(string) - if auth != "" { - sessionId = strings.Split(auth, " ")[1] - } - - repo := NewAuthRepository() + repo := NewRepository() err := repo.RemoveUserSession(sessionId, force == "true") if err != nil { diff --git a/server/app/hosts/repository.go b/server/app/hosts/repository.go index 0b88aa5..6d97aae 100644 --- a/server/app/hosts/repository.go +++ b/server/app/hosts/repository.go @@ -4,25 +4,36 @@ import ( "gorm.io/gorm" "rul.sh/vaulterm/db" "rul.sh/vaulterm/models" + "rul.sh/vaulterm/utils" ) -type Hosts struct{ db *gorm.DB } +type Hosts struct { + db *gorm.DB + User *utils.UserContext +} -func NewRepository() *Hosts { - return &Hosts{db: db.Get()} +func NewRepository(r *Hosts) *Hosts { + if r == nil { + r = &Hosts{} + } + r.db = db.Get() + return r } func (r *Hosts) GetAll() ([]*models.Host, error) { + query := r.ACL(r.db.Order("id DESC")) + var rows []*models.Host - ret := r.db.Order("id DESC").Find(&rows) + ret := query.Find(&rows) return rows, ret.Error } func (r *Hosts) Get(id string) (*models.HostDecrypted, error) { - var host models.Host - ret := r.db.Joins("Key").Joins("AltKey").Where("hosts.id = ?", id).First(&host) + query := r.ACL(r.db) + var host models.Host + ret := query.Joins("Key").Joins("AltKey").Where("hosts.id = ?", id).First(&host) if ret.Error != nil { return nil, ret.Error } @@ -37,12 +48,13 @@ func (r *Hosts) Get(id string) (*models.HostDecrypted, error) { func (r *Hosts) Exists(id string) (bool, error) { var count int64 - ret := r.db.Model(&models.Host{}).Where("id = ?", id).Count(&count) + ret := r.ACL(r.db.Model(&models.Host{}).Where("id = ?", id)).Count(&count) return count > 0, ret.Error } func (r *Hosts) Delete(id string) error { - return r.db.Delete(&models.Host{Model: models.Model{ID: id}}).Error + query := r.ACL(r.db) + return query.Delete(&models.Host{Model: models.Model{ID: id}}).Error } func (r *Hosts) Create(item *models.Host) error { @@ -50,5 +62,15 @@ func (r *Hosts) Create(item *models.Host) error { } func (r *Hosts) Update(id string, item *models.Host) error { - return r.db.Where("id = ?", id).Updates(item).Error + query := r.ACL(r.db.Where("id = ?", id)) + + return query.Updates(item).Error +} + +func (r *Hosts) ACL(query *gorm.DB) *gorm.DB { + if r.User.IsAdmin { + return query + } + + return query.Where("hosts.owner_id = ?", r.User.ID) } diff --git a/server/app/hosts/router.go b/server/app/hosts/router.go index 91b9476..1481c0f 100644 --- a/server/app/hosts/router.go +++ b/server/app/hosts/router.go @@ -9,7 +9,7 @@ import ( "rul.sh/vaulterm/utils" ) -func Router(app *fiber.App) { +func Router(app fiber.Router) { router := app.Group("/hosts") router.Get("/", getAll) @@ -19,7 +19,9 @@ func Router(app *fiber.App) { } func getAll(c *fiber.Ctx) error { - repo := NewRepository() + user := utils.GetUser(c) + repo := NewRepository(&Hosts{User: user}) + rows, err := repo.GetAll() if err != nil { return utils.ResponseError(c, err, 500) @@ -36,8 +38,11 @@ func create(c *fiber.Ctx) error { return utils.ResponseError(c, err, 500) } - repo := NewRepository() + user := utils.GetUser(c) + repo := NewRepository(&Hosts{User: user}) + item := &models.Host{ + OwnerID: user.ID, Type: body.Type, Label: body.Label, Host: body.Host, @@ -48,7 +53,7 @@ func create(c *fiber.Ctx) error { AltKeyID: body.AltKeyID, } - osName, err := tryConnect(item) + osName, err := tryConnect(c, item) if err != nil { return utils.ResponseError(c, fmt.Errorf("cannot connect to the host: %s", err), 500) } @@ -67,7 +72,8 @@ func update(c *fiber.Ctx) error { return utils.ResponseError(c, err, 500) } - repo := NewRepository() + user := utils.GetUser(c) + repo := NewRepository(&Hosts{User: user}) id := c.Params("id") exist, _ := repo.Exists(id) @@ -87,7 +93,7 @@ func update(c *fiber.Ctx) error { AltKeyID: body.AltKeyID, } - osName, err := tryConnect(item) + osName, err := tryConnect(c, item) if err != nil { return utils.ResponseError(c, fmt.Errorf("cannot connect to the host: %s", err), 500) } @@ -101,7 +107,8 @@ func update(c *fiber.Ctx) error { } func delete(c *fiber.Ctx) error { - repo := NewRepository() + user := utils.GetUser(c) + repo := NewRepository(&Hosts{User: user}) id := c.Params("id") exist, _ := repo.Exists(id) diff --git a/server/app/hosts/utils.go b/server/app/hosts/utils.go index 837edab..d953d7a 100644 --- a/server/app/hosts/utils.go +++ b/server/app/hosts/utils.go @@ -3,13 +3,16 @@ package hosts import ( "fmt" + "github.com/gofiber/fiber/v2" "rul.sh/vaulterm/app/keychains" "rul.sh/vaulterm/lib" "rul.sh/vaulterm/models" + "rul.sh/vaulterm/utils" ) -func tryConnect(host *models.Host) (string, error) { - keyRepo := keychains.NewRepository() +func tryConnect(c *fiber.Ctx, host *models.Host) (string, error) { + user := utils.GetUser(c) + keyRepo := keychains.NewRepository(&keychains.Keychains{User: user}) var key map[string]interface{} var altKey map[string]interface{} diff --git a/server/app/keychains/repository.go b/server/app/keychains/repository.go index 69dccc8..5cfe4fc 100644 --- a/server/app/keychains/repository.go +++ b/server/app/keychains/repository.go @@ -4,18 +4,27 @@ import ( "gorm.io/gorm" "rul.sh/vaulterm/db" "rul.sh/vaulterm/models" + "rul.sh/vaulterm/utils" ) -type Keychains struct{ db *gorm.DB } +type Keychains struct { + db *gorm.DB + User *utils.UserContext +} -func NewRepository() *Keychains { - return &Keychains{db: db.Get()} +func NewRepository(r *Keychains) *Keychains { + if r == nil { + r = &Keychains{} + } + r.db = db.Get() + return r } func (r *Keychains) GetAll() ([]*models.Keychain, error) { var rows []*models.Keychain - ret := r.db.Order("created_at DESC").Find(&rows) + query := r.ACL(r.db.Order("created_at DESC")) + ret := query.Find(&rows) return rows, ret.Error } @@ -25,7 +34,9 @@ func (r *Keychains) Create(item *models.Keychain) error { func (r *Keychains) Get(id string) (*models.Keychain, error) { var keychain models.Keychain - if err := r.db.Where("id = ?", id).First(&keychain).Error; err != nil { + query := r.ACL(r.db.Where("id = ?", id)) + + if err := query.First(&keychain).Error; err != nil { return nil, err } @@ -34,7 +45,8 @@ func (r *Keychains) Get(id string) (*models.Keychain, error) { func (r *Keychains) Exists(id string) (bool, error) { var count int64 - ret := r.db.Model(&models.Keychain{}).Where("id = ?", id).Count(&count) + query := r.ACL(r.db.Model(&models.Keychain{}).Where("id = ?", id)) + ret := query.Count(&count) return count > 0, ret.Error } @@ -58,5 +70,14 @@ func (r *Keychains) GetDecrypted(id string) (*KeychainDecrypted, error) { } func (r *Keychains) Update(id string, item *models.Keychain) error { - return r.db.Where("id = ?", id).Updates(item).Error + query := r.ACL(r.db.Where("id = ?", id)) + return query.Updates(item).Error +} + +func (r *Keychains) ACL(query *gorm.DB) *gorm.DB { + if r.User.IsAdmin { + return query + } + + return query.Where("keychains.owner_id = ?", r.User.ID) } diff --git a/server/app/keychains/router.go b/server/app/keychains/router.go index 9f1e543..8e200fa 100644 --- a/server/app/keychains/router.go +++ b/server/app/keychains/router.go @@ -9,7 +9,7 @@ import ( "rul.sh/vaulterm/utils" ) -func Router(app *fiber.App) { +func Router(app fiber.Router) { router := app.Group("/keychains") router.Get("/", getAll) @@ -25,7 +25,9 @@ type GetAllResult struct { func getAll(c *fiber.Ctx) error { withData := c.Query("withData") - repo := NewRepository() + user := utils.GetUser(c) + repo := NewRepository(&Keychains{User: user}) + rows, err := repo.GetAll() if err != nil { return utils.ResponseError(c, err, 500) @@ -62,11 +64,13 @@ func create(c *fiber.Ctx) error { return utils.ResponseError(c, err, 500) } - repo := NewRepository() + user := utils.GetUser(c) + repo := NewRepository(&Keychains{User: user}) item := &models.Keychain{ - Type: body.Type, - Label: body.Label, + OwnerID: user.ID, + Type: body.Type, + Label: body.Label, } if err := item.EncryptData(body.Data); err != nil { @@ -86,7 +90,9 @@ func update(c *fiber.Ctx) error { return utils.ResponseError(c, err, 500) } - repo := NewRepository() + user := utils.GetUser(c) + repo := NewRepository(&Keychains{User: user}) + id := c.Params("id") exist, _ := repo.Exists(id) diff --git a/server/app/router.go b/server/app/router.go new file mode 100644 index 0000000..94751ec --- /dev/null +++ b/server/app/router.go @@ -0,0 +1,23 @@ +package app + +import ( + "github.com/gofiber/fiber/v2" + "rul.sh/vaulterm/app/hosts" + "rul.sh/vaulterm/app/keychains" + "rul.sh/vaulterm/app/ws" +) + +func InitRouter(app *fiber.App) { + // App route list + routes := []Router{ + hosts.Router, + keychains.Router, + ws.Router, + } + + for _, route := range routes { + route(app) + } +} + +type Router func(app fiber.Router) diff --git a/server/app/ws/router.go b/server/app/ws/router.go index 9891caf..d0753fa 100644 --- a/server/app/ws/router.go +++ b/server/app/ws/router.go @@ -5,7 +5,7 @@ import ( "github.com/gofiber/fiber/v2" ) -func Router(app *fiber.App) { +func Router(app fiber.Router) { router := app.Group("/ws") router.Use(func(c *fiber.Ctx) error { diff --git a/server/app/ws/term.go b/server/app/ws/term.go index e82e4f3..f62a005 100644 --- a/server/app/ws/term.go +++ b/server/app/ws/term.go @@ -13,7 +13,8 @@ import ( func HandleTerm(c *websocket.Conn) { hostId := c.Query("hostId") - hostRepo := hosts.NewRepository() + user := utils.GetUserWs(c) + hostRepo := hosts.NewRepository(&hosts.Hosts{User: user}) data, err := hostRepo.Get(hostId) if data == nil { diff --git a/server/middleware/auth.go b/server/middleware/auth.go new file mode 100644 index 0000000..559c7e7 --- /dev/null +++ b/server/middleware/auth.go @@ -0,0 +1,51 @@ +package middleware + +import ( + "strings" + + "github.com/gofiber/fiber/v2" + "rul.sh/vaulterm/db" + "rul.sh/vaulterm/models" +) + +func Auth(c *fiber.Ctx) error { + authHeader := c.Get("Authorization") + var sessionId string + + if authHeader != "" { + sessionId = strings.Split(authHeader, " ")[1] + } + + if strings.HasPrefix(c.Path(), "/ws") && sessionId == "" { + sessionId = c.Query("sid") + } + + session, _ := GetUserSession(sessionId) + + if session != nil && session.User.ID != "" { + c.Locals("user", &session.User) + c.Locals("sessionId", sessionId) + } + + return c.Next() +} + +func GetUserSession(sessionId string) (*models.UserSession, error) { + var session models.UserSession + res := db.Get().Joins("User").Where("user_sessions.id = ?", sessionId).First(&session) + return &session, res.Error +} + +func Protected() func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + user := c.Locals("user") + if user == nil { + return &fiber.Error{ + Code: fiber.StatusUnauthorized, + Message: "Unauthorized", + } + } + return c.Next() + } + +} diff --git a/server/models/host.go b/server/models/host.go index 404223b..9bdbb80 100644 --- a/server/models/host.go +++ b/server/models/host.go @@ -14,6 +14,9 @@ const ( type Host struct { Model + OwnerID string `json:"userId" gorm:"index:hosts_owner_id_idx;type:varchar(26)"` + Owner User `json:"user" gorm:"foreignKey:OwnerID"` + Type string `json:"type" gorm:"not null;index:hosts_type_idx;type:varchar(16)"` Label string `json:"label"` Host string `json:"host" gorm:"type:varchar(64)"` @@ -54,3 +57,14 @@ func (h *Host) DecryptKeys() (*HostDecrypted, error) { return res, nil } + +type HostHasAccessOptions struct { + UserID string +} + +func (h *Host) HasAccess(o HostHasAccessOptions) bool { + if o.UserID == h.OwnerID { + return true + } + return false +} diff --git a/server/models/keychain.go b/server/models/keychain.go index b002e0d..9ae6f17 100644 --- a/server/models/keychain.go +++ b/server/models/keychain.go @@ -16,6 +16,9 @@ const ( type Keychain struct { Model + OwnerID string `json:"userId" gorm:"index:hosts_owner_id_idx;type:varchar(26)"` + Owner User `json:"user" gorm:"foreignKey:OwnerID"` + Label string `json:"label"` Type string `json:"type" gorm:"not null;index:keychains_type_idx;type:varchar(12)"` Data string `json:"-" gorm:"type:text"` diff --git a/server/utils/context.go b/server/utils/context.go new file mode 100644 index 0000000..e58d5a8 --- /dev/null +++ b/server/utils/context.go @@ -0,0 +1,35 @@ +package utils + +import ( + "github.com/gofiber/contrib/websocket" + "github.com/gofiber/fiber/v2" + "rul.sh/vaulterm/models" +) + +type UserContext struct { + *models.User + IsAdmin bool +} + +func getUserData(user *models.User) *UserContext { + isAdmin := false + + if user.Role == models.UserRoleAdmin { + isAdmin = true + } + + return &UserContext{ + User: user, + IsAdmin: isAdmin, + } +} + +func GetUser(c *fiber.Ctx) *UserContext { + user := c.Locals("user").(*models.User) + return getUserData(user) +} + +func GetUserWs(c *websocket.Conn) *UserContext { + user := c.Locals("user").(*models.User) + return getUserData(user) +}