diff --git a/frontend/app/(drawer)/_layout.tsx b/frontend/app/(drawer)/_layout.tsx index be5a791..fd06a93 100644 --- a/frontend/app/(drawer)/_layout.tsx +++ b/frontend/app/(drawer)/_layout.tsx @@ -2,11 +2,17 @@ import { GestureHandlerRootView } from "react-native-gesture-handler"; import { Drawer } from "expo-router/drawer"; import React from "react"; import { useMedia } from "tamagui"; -import DrawerContent from "@/components/containers/drawer"; +import DrawerContent, { + DrawerNavigationOptions, +} from "@/components/containers/drawer"; import Icons from "@/components/ui/icons"; +import { useUser } from "@/hooks/useUser"; +import { useTeamId } from "@/stores/auth"; export default function Layout() { const media = useMedia(); + const teamId = useTeamId(); + const user = useUser(); return ( @@ -29,12 +35,15 @@ export default function Layout() { /> ( - - ), - }} + options={ + { + title: "Keychains", + hidden: teamId && !user?.teamCanWrite(teamId), + drawerIcon: ({ size, color }) => ( + + ), + } as DrawerNavigationOptions + } /> { return ( <> @@ -61,7 +66,12 @@ const DrawerItemList = ({ } }; - const { title, drawerLabel, drawerIcon } = descriptors[route.key].options; + const { title, drawerLabel, drawerIcon, hidden } = descriptors[route.key] + .options as DrawerNavigationOptions; + + if (hidden) { + return null; + } return ( diff --git a/frontend/components/containers/user-menu-button.tsx b/frontend/components/containers/user-menu-button.tsx index db592c3..51d6569 100644 --- a/frontend/components/containers/user-menu-button.tsx +++ b/frontend/components/containers/user-menu-button.tsx @@ -11,11 +11,13 @@ import { } from "tamagui"; import MenuButton from "../ui/menu-button"; import Icons from "../ui/icons"; -import { logout } from "@/stores/auth"; +import { logout, setTeam, useTeamId } from "@/stores/auth"; import { useUser } from "@/hooks/useUser"; const UserMenuButton = () => { const user = useUser(); + const teamId = useTeamId(); + const team = user?.teams?.find((t: any) => t.id === teamId); return ( { {user?.name} - Personal + {team ? `${team.icon} ${team.name}` : "Personal"} @@ -61,6 +63,7 @@ const UserMenuButton = () => { const TeamsMenu = () => { const media = useMedia(); const user = useUser(); + const teamId = useTeamId(); const teams = user?.teams || []; return ( @@ -73,20 +76,30 @@ const TeamsMenu = () => { console.log("logout")} icon={} title="Teams" iconAfter={} /> } > - } - title="Personal" - /> + {teamId != null && ( + } + title="Personal" + onPress={() => setTeam(null)} + /> + )} {teams.map((team: any) => ( - {team.icon}} title={team.name} /> + {team.icon}} + iconAfter={ + teamId === team.id ? : undefined + } + title={team.name} + onPress={() => setTeam(team.id)} + /> ))} {teams.length > 0 && } diff --git a/frontend/hooks/useUser.ts b/frontend/hooks/useUser.ts index 08cf7f1..2c397db 100644 --- a/frontend/hooks/useUser.ts +++ b/frontend/hooks/useUser.ts @@ -6,5 +6,32 @@ export const useUser = () => { queryKey: ["auth", "user"], queryFn: authRepo.getUser, }); - return user; + + if (!user) { + return null; + } + + function getTeamRole(teamId?: string | null) { + if (!user.teams?.length) { + return false; + } + const team = user.teams.find((i: any) => i.id === teamId); + return team?.role; + } + + function isInTeam(teamId?: string | null) { + return getTeamRole(teamId) != null; + } + + function teamCanWrite(teamId?: string | null) { + const role = getTeamRole(teamId); + return ["admin", "owner"].includes(role); + } + + return { + ...user, + getTeamRole, + isInTeam, + teamCanWrite, + }; }; diff --git a/frontend/lib/api.ts b/frontend/lib/api.ts index b276f4b..7386624 100644 --- a/frontend/lib/api.ts +++ b/frontend/lib/api.ts @@ -1,6 +1,5 @@ import { getCurrentServer } from "@/stores/app"; import authStore from "@/stores/auth"; -import { QueryClient } from "@tanstack/react-query"; import { ofetch } from "ofetch"; const api = ofetch.create({ @@ -13,9 +12,13 @@ const api = ofetch.create({ // set server url config.options.baseURL = server.url; - const authToken = authStore.getState().token; - if (authToken) { - config.options.headers.set("Authorization", `Bearer ${authToken}`); + const { token, teamId } = authStore.getState(); + + if (token) { + config.options.headers.set("Authorization", `Bearer ${token}`); + } + if (teamId) { + config.options.headers.set("X-Team-Id", teamId); } }, onResponseError: (error) => { @@ -31,6 +34,4 @@ const api = ofetch.create({ }, }); -export const queryClient = new QueryClient(); - export default api; diff --git a/frontend/lib/queryClient.ts b/frontend/lib/queryClient.ts new file mode 100644 index 0000000..db2b63d --- /dev/null +++ b/frontend/lib/queryClient.ts @@ -0,0 +1,5 @@ +import { QueryClient } from "@tanstack/react-query"; + +const queryClient = new QueryClient(); + +export default queryClient; diff --git a/frontend/pages/auth/login.tsx b/frontend/pages/auth/login.tsx index 4563872..52e4370 100644 --- a/frontend/pages/auth/login.tsx +++ b/frontend/pages/auth/login.tsx @@ -47,6 +47,7 @@ export default function LoginPage() { marginHorizontal: "auto", }, title: "Login", + headerTitle: "", headerRight: () => ( ), diff --git a/frontend/pages/hosts/components/host-list.tsx b/frontend/pages/hosts/components/host-list.tsx index d1a949f..f44a09c 100644 --- a/frontend/pages/hosts/components/host-list.tsx +++ b/frontend/pages/hosts/components/host-list.tsx @@ -8,6 +8,7 @@ import { useTermSession } from "@/stores/terminal-sessions"; import { hostFormModal } from "./form"; import GridView from "@/components/ui/grid-view"; import HostItem from "./host-item"; +import { useHosts } from "../hooks/query"; type HostsListProps = { allowEdit?: boolean; @@ -18,11 +19,7 @@ const HostList = ({ allowEdit = true }: HostsListProps) => { const navigation = useNavigation(); const [search, setSearch] = useState(""); - const hosts = useQuery({ - queryKey: ["hosts"], - queryFn: () => api("/hosts"), - select: (i) => i.rows, - }); + const hosts = useHosts(); const hostsList = useMemo(() => { let items = hosts.data || []; diff --git a/frontend/pages/hosts/hooks/query.ts b/frontend/pages/hosts/hooks/query.ts index 6943453..2acfd5c 100644 --- a/frontend/pages/hosts/hooks/query.ts +++ b/frontend/pages/hosts/hooks/query.ts @@ -1,8 +1,19 @@ import { useMutation, useQuery } from "@tanstack/react-query"; import { FormSchema } from "../schema/form"; -import api, { queryClient } from "@/lib/api"; +import api from "@/lib/api"; import { useMemo } from "react"; import { useKeychains } from "@/pages/keychains/hooks/query"; +import queryClient from "@/lib/queryClient"; +import { useTeamId } from "@/stores/auth"; + +export const useHosts = () => { + const teamId = useTeamId(); + return useQuery({ + queryKey: ["hosts", teamId], + queryFn: () => api("/hosts", { params: { teamId } }), + select: (i) => i.rows, + }); +}; export const useKeychainsOptions = () => { const keys = useKeychains(); @@ -20,8 +31,11 @@ export const useKeychainsOptions = () => { }; export const useSaveHost = () => { + const teamId = useTeamId(); + return useMutation({ - mutationFn: async (body: FormSchema) => { + mutationFn: async (payload: FormSchema) => { + const body = { teamId, ...payload }; return body.id ? api(`/hosts/${body.id}`, { method: "PUT", body }) : api(`/hosts`, { method: "POST", body }); diff --git a/frontend/pages/keychains/hooks/query.ts b/frontend/pages/keychains/hooks/query.ts index 64c7d92..1affc11 100644 --- a/frontend/pages/keychains/hooks/query.ts +++ b/frontend/pages/keychains/hooks/query.ts @@ -1,8 +1,13 @@ -import api, { queryClient } from "@/lib/api"; +import api from "@/lib/api"; import { useMutation, useQuery } from "@tanstack/react-query"; import { FormSchema } from "../schema/form"; +import queryClient from "@/lib/queryClient"; +import { useTeamId } from "@/stores/auth"; + +export const useKeychains = (params?: any) => { + const teamId = useTeamId(); + const query = { teamId, ...params }; -export const useKeychains = (query?: any) => { return useQuery({ queryKey: ["keychains", query], queryFn: () => api("/keychains", { query }), @@ -11,8 +16,11 @@ export const useKeychains = (query?: any) => { }; export const useSaveKeychain = () => { + const teamId = useTeamId(); + return useMutation({ - mutationFn: async (body: FormSchema) => { + mutationFn: async (payload: FormSchema) => { + const body = { teamId, ...payload }; return body.id ? api(`/keychains/${body.id}`, { method: "PUT", body }) : api(`/keychains`, { method: "POST", body }); diff --git a/frontend/stores/auth.ts b/frontend/stores/auth.ts index eeb9303..0c49dd4 100644 --- a/frontend/stores/auth.ts +++ b/frontend/stores/auth.ts @@ -2,15 +2,18 @@ import { createStore, useStore } from "zustand"; import { persist, createJSONStorage } from "zustand/middleware"; import AsyncStorage from "@react-native-async-storage/async-storage"; import termSessionStore from "./terminal-sessions"; +import queryClient from "@/lib/queryClient"; type AuthStore = { - token?: string | null; + token: string | null; + teamId: string | null; }; const authStore = createStore( persist( () => ({ token: null, + teamId: null, }), { name: "vaulterm:auth", @@ -24,9 +27,18 @@ export const useAuthStore = () => { return { ...state, isLoggedIn: state.token != null }; }; +export const setTeam = (teamId: string | null) => { + authStore.setState({ teamId }); + queryClient.invalidateQueries(); +}; + export const logout = () => { - authStore.setState({ token: null }); + authStore.setState({ token: null, teamId: null }); termSessionStore.setState({ sessions: [], curSession: 0 }); }; +export const useTeamId = () => { + return useStore(authStore, (i) => i.teamId); +}; + export default authStore; diff --git a/server/app/auth/router.go b/server/app/auth/router.go index fbb6352..55579d6 100644 --- a/server/app/auth/router.go +++ b/server/app/auth/router.go @@ -54,7 +54,21 @@ func login(c *fiber.Ctx) error { func getUser(c *fiber.Ctx) error { user := utils.GetUser(c) - return c.JSON(user) + teams := []TeamWithRole{} + + for _, item := range user.Teams { + teams = append(teams, TeamWithRole{ + ID: item.TeamID, + Name: item.Team.Name, + Icon: item.Team.Icon, + Role: item.Role, + }) + } + + return c.JSON(&GetUserResult{ + AuthUser: *user, + Teams: teams, + }) } func logout(c *fiber.Ctx) error { diff --git a/server/app/auth/schema.go b/server/app/auth/schema.go index 9376ffa..6b47e47 100644 --- a/server/app/auth/schema.go +++ b/server/app/auth/schema.go @@ -1,6 +1,20 @@ package auth +import "rul.sh/vaulterm/middleware" + type LoginSchema struct { Username string `json:"username"` Password string `json:"password"` } + +type TeamWithRole struct { + ID string `json:"id"` + Name string `json:"name"` + Icon string `json:"icon"` + Role string `json:"role"` +} + +type GetUserResult struct { + middleware.AuthUser + Teams []TeamWithRole `json:"teams"` +} diff --git a/server/app/hosts/repository.go b/server/app/hosts/repository.go index 6d97aae..4e9830b 100644 --- a/server/app/hosts/repository.go +++ b/server/app/hosts/repository.go @@ -20,8 +20,14 @@ func NewRepository(r *Hosts) *Hosts { return r } -func (r *Hosts) GetAll() ([]*models.Host, error) { - query := r.ACL(r.db.Order("id DESC")) +func (r *Hosts) GetAll(opt GetAllOpt) ([]*models.Host, error) { + query := r.db.Order("id DESC") + + if opt.TeamID != "" { + query = query.Where("hosts.team_id = ?", opt.TeamID) + } else { + query = query.Where("hosts.owner_id = ? AND hosts.team_id IS NULL", r.User.ID) + } var rows []*models.Host ret := query.Find(&rows) @@ -30,10 +36,8 @@ func (r *Hosts) GetAll() ([]*models.Host, error) { } func (r *Hosts) Get(id string) (*models.HostDecrypted, error) { - query := r.ACL(r.db) - var host models.Host - ret := query.Joins("Key").Joins("AltKey").Where("hosts.id = ?", id).First(&host) + ret := r.db.Joins("Key").Joins("AltKey").Where("hosts.id = ?", id).First(&host) if ret.Error != nil { return nil, ret.Error } @@ -48,13 +52,12 @@ func (r *Hosts) Get(id string) (*models.HostDecrypted, error) { func (r *Hosts) Exists(id string) (bool, error) { var count int64 - ret := r.ACL(r.db.Model(&models.Host{}).Where("id = ?", id)).Count(&count) + ret := r.db.Model(&models.Host{}).Where("id = ?", id).Count(&count) return count > 0, ret.Error } func (r *Hosts) Delete(id string) error { - query := r.ACL(r.db) - return query.Delete(&models.Host{Model: models.Model{ID: id}}).Error + return r.db.Delete(&models.Host{Model: models.Model{ID: id}}).Error } func (r *Hosts) Create(item *models.Host) error { @@ -62,15 +65,5 @@ func (r *Hosts) Create(item *models.Host) error { } func (r *Hosts) Update(id string, item *models.Host) error { - query := r.ACL(r.db.Where("id = ?", id)) - - return query.Updates(item).Error -} - -func (r *Hosts) ACL(query *gorm.DB) *gorm.DB { - if r.User.IsAdmin { - return query - } - - return query.Where("hosts.owner_id = ?", r.User.ID) + return r.db.Where("id = ?", id).Updates(item).Error } diff --git a/server/app/hosts/router.go b/server/app/hosts/router.go index 1481c0f..0bb51ac 100644 --- a/server/app/hosts/router.go +++ b/server/app/hosts/router.go @@ -1,6 +1,7 @@ package hosts import ( + "errors" "fmt" "net/http" @@ -19,10 +20,15 @@ func Router(app fiber.Router) { } func getAll(c *fiber.Ctx) error { + teamId := c.Query("teamId") user := utils.GetUser(c) repo := NewRepository(&Hosts{User: user}) - rows, err := repo.GetAll() + if teamId != "" && !user.IsInTeam(&teamId) { + return utils.ResponseError(c, errors.New("no access"), 403) + } + + rows, err := repo.GetAll(GetAllOpt{TeamID: teamId}) if err != nil { return utils.ResponseError(c, err, 500) } @@ -41,8 +47,13 @@ func create(c *fiber.Ctx) error { user := utils.GetUser(c) repo := NewRepository(&Hosts{User: user}) + if body.TeamID != nil && !user.TeamCanWrite(body.TeamID) { + return utils.ResponseError(c, errors.New("no access"), 403) + } + item := &models.Host{ - OwnerID: user.ID, + OwnerID: &user.ID, + TeamID: body.TeamID, Type: body.Type, Label: body.Label, Host: body.Host, @@ -76,13 +87,17 @@ func update(c *fiber.Ctx) error { repo := NewRepository(&Hosts{User: user}) id := c.Params("id") - exist, _ := repo.Exists(id) - if !exist { - return utils.ResponseError(c, fmt.Errorf("host %s not found", id), 404) + data, _ := repo.Get(id) + if data == nil { + return utils.ResponseError(c, errors.New("host not found"), 404) + } + if !data.CanWrite(&user.User) || !user.TeamCanWrite(body.TeamID) { + return utils.ResponseError(c, errors.New("no access"), 403) } item := &models.Host{ Model: models.Model{ID: id}, + TeamID: body.TeamID, Type: body.Type, Label: body.Label, Host: body.Host, @@ -111,9 +126,12 @@ func delete(c *fiber.Ctx) error { repo := NewRepository(&Hosts{User: user}) id := c.Params("id") - exist, _ := repo.Exists(id) - if !exist { - return utils.ResponseError(c, fmt.Errorf("host %s not found", id), 404) + host, _ := repo.Get(id) + if host == nil { + return utils.ResponseError(c, errors.New("host not found"), 404) + } + if !host.CanWrite(&user.User) { + return utils.ResponseError(c, errors.New("no access"), 403) } if err := repo.Delete(id); err != nil { diff --git a/server/app/hosts/schema.go b/server/app/hosts/schema.go index 511d9df..1f96d66 100644 --- a/server/app/hosts/schema.go +++ b/server/app/hosts/schema.go @@ -9,7 +9,12 @@ type CreateHostSchema struct { Port int `json:"port"` Metadata datatypes.JSONMap `json:"metadata"` + TeamID *string `json:"teamId"` ParentID *string `json:"parentId"` KeyID *string `json:"keyId"` AltKeyID *string `json:"altKeyId"` } + +type GetAllOpt struct { + TeamID string +} diff --git a/server/app/keychains/repository.go b/server/app/keychains/repository.go index 5cfe4fc..8b6ffa9 100644 --- a/server/app/keychains/repository.go +++ b/server/app/keychains/repository.go @@ -20,10 +20,16 @@ func NewRepository(r *Keychains) *Keychains { return r } -func (r *Keychains) GetAll() ([]*models.Keychain, error) { - var rows []*models.Keychain - query := r.ACL(r.db.Order("created_at DESC")) +func (r *Keychains) GetAll(opt GetAllOpt) ([]*models.Keychain, error) { + query := r.db.Order("created_at DESC") + if opt.TeamID != "" { + query = query.Where("keychains.team_id = ?", opt.TeamID) + } else { + query = query.Where("keychains.owner_id = ? AND keychains.team_id IS NULL", r.User.ID) + } + + var rows []*models.Keychain ret := query.Find(&rows) return rows, ret.Error } @@ -34,9 +40,7 @@ func (r *Keychains) Create(item *models.Keychain) error { func (r *Keychains) Get(id string) (*models.Keychain, error) { var keychain models.Keychain - query := r.ACL(r.db.Where("id = ?", id)) - - if err := query.First(&keychain).Error; err != nil { + if err := r.db.Where("id = ?", id).First(&keychain).Error; err != nil { return nil, err } @@ -45,8 +49,7 @@ func (r *Keychains) Get(id string) (*models.Keychain, error) { func (r *Keychains) Exists(id string) (bool, error) { var count int64 - query := r.ACL(r.db.Model(&models.Keychain{}).Where("id = ?", id)) - ret := query.Count(&count) + ret := r.db.Model(&models.Keychain{}).Where("id = ?", id).Count(&count) return count > 0, ret.Error } @@ -70,14 +73,5 @@ func (r *Keychains) GetDecrypted(id string) (*KeychainDecrypted, error) { } func (r *Keychains) Update(id string, item *models.Keychain) error { - query := r.ACL(r.db.Where("id = ?", id)) - return query.Updates(item).Error -} - -func (r *Keychains) ACL(query *gorm.DB) *gorm.DB { - if r.User.IsAdmin { - return query - } - - return query.Where("keychains.owner_id = ?", r.User.ID) + return r.db.Where("id = ?", id).Updates(item).Error } diff --git a/server/app/keychains/router.go b/server/app/keychains/router.go index 8e200fa..fe0a577 100644 --- a/server/app/keychains/router.go +++ b/server/app/keychains/router.go @@ -1,7 +1,7 @@ package keychains import ( - "fmt" + "errors" "net/http" "github.com/gofiber/fiber/v2" @@ -23,17 +23,22 @@ type GetAllResult struct { } func getAll(c *fiber.Ctx) error { + teamId := c.Query("teamId") withData := c.Query("withData") user := utils.GetUser(c) repo := NewRepository(&Keychains{User: user}) - rows, err := repo.GetAll() + if teamId != "" && !user.IsInTeam(&teamId) { + return utils.ResponseError(c, errors.New("no access"), 403) + } + + rows, err := repo.GetAll(GetAllOpt{TeamID: teamId}) if err != nil { return utils.ResponseError(c, err, 500) } - if withData != "true" { + if withData != "true" || (teamId != "" && !user.TeamCanWrite(&teamId)) { return c.JSON(fiber.Map{"rows": rows}) } @@ -67,8 +72,13 @@ func create(c *fiber.Ctx) error { user := utils.GetUser(c) repo := NewRepository(&Keychains{User: user}) + if body.TeamID != nil && !user.TeamCanWrite(body.TeamID) { + return utils.ResponseError(c, errors.New("no access"), 403) + } + item := &models.Keychain{ - OwnerID: user.ID, + OwnerID: &user.ID, + TeamID: body.TeamID, Type: body.Type, Label: body.Label, } @@ -94,15 +104,18 @@ func update(c *fiber.Ctx) error { repo := NewRepository(&Keychains{User: user}) id := c.Params("id") - - exist, _ := repo.Exists(id) - if !exist { - return utils.ResponseError(c, fmt.Errorf("key %s not found", id), 404) + data, _ := repo.Get(id) + if data == nil { + return utils.ResponseError(c, errors.New("key not found"), 404) + } + if !data.CanWrite(&user.User) || !user.TeamCanWrite(body.TeamID) { + return utils.ResponseError(c, errors.New("no access"), 403) } item := &models.Keychain{ - Type: body.Type, - Label: body.Label, + TeamID: body.TeamID, + Type: body.Type, + Label: body.Label, } if err := item.EncryptData(body.Data); err != nil { diff --git a/server/app/keychains/schema.go b/server/app/keychains/schema.go index c7f7086..ad8d4b4 100644 --- a/server/app/keychains/schema.go +++ b/server/app/keychains/schema.go @@ -1,7 +1,12 @@ package keychains type CreateKeychainSchema struct { - Type string `json:"type"` - Label string `json:"label"` - Data interface{} `json:"data"` + TeamID *string `json:"teamId"` + Type string `json:"type"` + Label string `json:"label"` + Data interface{} `json:"data"` +} + +type GetAllOpt struct { + TeamID string } diff --git a/server/app/teams/repository.go b/server/app/teams/repository.go new file mode 100644 index 0000000..efe8aca --- /dev/null +++ b/server/app/teams/repository.go @@ -0,0 +1,50 @@ +package teams + +import ( + "gorm.io/gorm" + "rul.sh/vaulterm/db" + "rul.sh/vaulterm/models" + "rul.sh/vaulterm/utils" +) + +type Teams struct { + db *gorm.DB + User *utils.UserContext +} + +func NewRepository(r *Teams) *Teams { + if r == nil { + r = &Teams{} + } + r.db = db.Get() + return r +} + +func (r *Teams) GetAll() ([]*models.Team, error) { + var rows []*models.Team + ret := r.db.Order("created_at DESC").Find(&rows) + return rows, ret.Error +} + +func (r *Teams) Create(data *models.Team) error { + return r.db.Create(data).Error +} + +func (r *Teams) Get(id string) (*models.Team, error) { + var data models.Team + if err := r.db.Where("id = ?", id).First(&data).Error; err != nil { + return nil, err + } + + return &data, nil +} + +func (r *Teams) Exists(id string) (bool, error) { + var count int64 + ret := r.db.Model(&models.Team{}).Where("id = ?", id).Count(&count) + return count > 0, ret.Error +} + +func (r *Teams) Update(id string, item *models.Team) error { + return r.db.Where("id = ?", id).Updates(item).Error +} diff --git a/server/app/ws/term.go b/server/app/ws/term.go index f62a005..958f2cb 100644 --- a/server/app/ws/term.go +++ b/server/app/ws/term.go @@ -17,8 +17,8 @@ func HandleTerm(c *websocket.Conn) { hostRepo := hosts.NewRepository(&hosts.Hosts{User: user}) data, err := hostRepo.Get(hostId) - if data == nil { - log.Printf("Cannot find host! Error: %s\n", err.Error()) + if data == nil || !data.HasAccess(&user.User) { + log.Printf("Cannot find host! %v\n", err) c.WriteMessage(websocket.TextMessage, []byte("Host not found")) return } diff --git a/server/db/database.go b/server/db/database.go index c3f9e03..a44e976 100644 --- a/server/db/database.go +++ b/server/db/database.go @@ -45,7 +45,6 @@ func Init() { // Migrate the schema db.AutoMigrate(Models...) - InitModels(db) runSeeders(db) } diff --git a/server/db/models.go b/server/db/models.go index 664d957..fe3867d 100644 --- a/server/db/models.go +++ b/server/db/models.go @@ -1,9 +1,6 @@ package db import ( - "log" - - "gorm.io/gorm" "rul.sh/vaulterm/models" ) @@ -15,9 +12,3 @@ var Models = []interface{}{ &models.Team{}, &models.TeamMembers{}, } - -func InitModels(db *gorm.DB) { - if err := db.SetupJoinTable(&models.Team{}, "Members", &models.TeamMembers{}); err != nil { - log.Fatal(err) - } -} diff --git a/server/db/seeders.go b/server/db/seeders.go index 4da3a9c..11c466f 100644 --- a/server/db/seeders.go +++ b/server/db/seeders.go @@ -66,9 +66,9 @@ func seedUsers(tx *gorm.DB) error { } teamMembers := []models.TeamMembers{ - {TeamID: teams[0].ID, UserID: userList[0].ID, Role: "owner"}, - {TeamID: teams[0].ID, UserID: userList[1].ID, Role: "admin"}, - {TeamID: teams[0].ID, UserID: userList[2].ID, Role: "user"}, + {TeamID: teams[0].ID, UserID: userList[0].ID, Role: models.TeamRoleOwner}, + {TeamID: teams[0].ID, UserID: userList[1].ID, Role: models.TeamRoleAdmin}, + {TeamID: teams[0].ID, UserID: userList[2].ID, Role: models.TeamRoleMember}, } if res := tx.Create(&teamMembers); res.Error != nil { diff --git a/server/middleware/auth.go b/server/middleware/auth.go index b09c661..0b05af3 100644 --- a/server/middleware/auth.go +++ b/server/middleware/auth.go @@ -4,7 +4,6 @@ import ( "strings" "github.com/gofiber/fiber/v2" - "gorm.io/gorm" "rul.sh/vaulterm/db" "rul.sh/vaulterm/models" ) @@ -23,25 +22,35 @@ func Auth(c *fiber.Ctx) error { session, _ := GetUserSession(sessionId) - if session != nil && session.User.ID != "" { - c.Locals("user", &session.User) + if session != nil && session.ID != "" { + c.Locals("user", session) c.Locals("sessionId", sessionId) } return c.Next() } -func GetUserSession(sessionId string) (*models.UserSession, error) { - var session models.UserSession +type AuthUser struct { + models.User + SessionID string `json:"sessionId" gorm:"column:session_id"` +} + +func GetUserSession(sessionId string) (*AuthUser, error) { + var session AuthUser + res := db.Get(). - Joins("User"). - Preload("User.Teams", func(db *gorm.DB) *gorm.DB { - return db.Select("id", "name", "icon") - }). + Model(&models.User{}). + Joins("JOIN user_sessions ON user_sessions.user_id = users.id"). + Preload("Teams.Team"). + Select("users.*, user_sessions.id AS session_id"). Where("user_sessions.id = ?", sessionId). First(&session) - return &session, res.Error + if res.Error != nil || session.User.ID == "" { + return nil, res.Error + } + + return &session, nil } func Protected() func(c *fiber.Ctx) error { diff --git a/server/models/host.go b/server/models/host.go index 9bdbb80..309aa28 100644 --- a/server/models/host.go +++ b/server/models/host.go @@ -1,6 +1,8 @@ package models -import "gorm.io/datatypes" +import ( + "gorm.io/datatypes" +) const ( HostTypeSSH = "ssh" @@ -14,8 +16,10 @@ const ( type Host struct { Model - OwnerID string `json:"userId" gorm:"index:hosts_owner_id_idx;type:varchar(26)"` - Owner User `json:"user" gorm:"foreignKey:OwnerID"` + OwnerID *string `json:"userId" gorm:"type:varchar(26)"` + Owner *User `json:"user" gorm:"foreignKey:OwnerID"` + TeamID *string `json:"teamId" gorm:"type:varchar(26)"` + Team *Team `json:"team" gorm:"foreignKey:TeamID"` Type string `json:"type" gorm:"not null;index:hosts_type_idx;type:varchar(16)"` Label string `json:"label"` @@ -24,11 +28,11 @@ type Host struct { OS string `json:"os" gorm:"type:varchar(32)"` Metadata datatypes.JSONMap `json:"metadata"` - ParentID *string `json:"parentId" gorm:"index:hosts_parent_id_idx;type:varchar(26)"` + ParentID *string `json:"parentId" gorm:"type:varchar(26)"` Parent *Host `json:"parent" gorm:"foreignKey:ParentID"` - KeyID *string `json:"keyId" gorm:"index:hosts_key_id_idx"` + KeyID *string `json:"keyId" gorm:"type:varchar(26)"` Key Keychain `json:"key" gorm:"foreignKey:KeyID"` - AltKeyID *string `json:"altKeyId" gorm:"index:hosts_altkey_id_idx"` + AltKeyID *string `json:"altKeyId" gorm:"type:varchar(26)"` AltKey Keychain `json:"altKey" gorm:"foreignKey:AltKeyID"` Timestamps @@ -58,13 +62,17 @@ func (h *Host) DecryptKeys() (*HostDecrypted, error) { return res, nil } -type HostHasAccessOptions struct { - UserID string -} - -func (h *Host) HasAccess(o HostHasAccessOptions) bool { - if o.UserID == h.OwnerID { +func (h *Host) HasAccess(user *User) bool { + if user.IsAdmin() { return true } - return false + return *h.OwnerID == user.ID || user.IsInTeam(h.TeamID) +} + +func (h *Host) CanWrite(user *User) bool { + if user.IsAdmin() { + return true + } + teamRole := user.GetTeamRole(h.TeamID) + return *h.OwnerID == user.ID || teamRole == TeamRoleOwner || teamRole == TeamRoleAdmin } diff --git a/server/models/keychain.go b/server/models/keychain.go index 9ae6f17..c779d39 100644 --- a/server/models/keychain.go +++ b/server/models/keychain.go @@ -16,8 +16,10 @@ const ( type Keychain struct { Model - OwnerID string `json:"userId" gorm:"index:hosts_owner_id_idx;type:varchar(26)"` - Owner User `json:"user" gorm:"foreignKey:OwnerID"` + OwnerID *string `json:"userId" gorm:"type:varchar(26)"` + Owner *User `json:"user" gorm:"foreignKey:OwnerID"` + TeamID *string `json:"teamId" gorm:"type:varchar(26)"` + Team *Team `json:"team" gorm:"foreignKey:TeamID"` Label string `json:"label"` Type string `json:"type" gorm:"not null;index:keychains_type_idx;type:varchar(12)"` @@ -55,3 +57,18 @@ func (k *Keychain) DecryptData(data interface{}) error { return nil } + +func (k *Keychain) HasAccess(user *User) bool { + if user.IsAdmin() { + return true + } + return *k.OwnerID == user.ID || user.IsInTeam(k.TeamID) +} + +func (k *Keychain) CanWrite(user *User) bool { + if user.IsAdmin() { + return true + } + teamRole := user.GetTeamRole(k.TeamID) + return *k.OwnerID == user.ID || teamRole == TeamRoleOwner || teamRole == TeamRoleAdmin +} diff --git a/server/models/team.go b/server/models/team.go index 2d607e0..86e242b 100644 --- a/server/models/team.go +++ b/server/models/team.go @@ -2,12 +2,18 @@ package models import "time" +const ( + TeamRoleOwner = "owner" + TeamRoleAdmin = "admin" + TeamRoleMember = "member" +) + type Team struct { Model - Name string `json:"name" gorm:"type:varchar(32)"` - Icon string `json:"icon" gorm:"type:varchar(2)"` - Members []*User `json:"members" gorm:"many2many:team_members"` + Name string `json:"name" gorm:"type:varchar(32)"` + Icon string `json:"icon" gorm:"type:varchar(2)"` + Members []*TeamMembers `json:"members" gorm:"foreignKey:TeamID"` Timestamps SoftDeletes diff --git a/server/models/user.go b/server/models/user.go index cfc0b57..8a084e1 100644 --- a/server/models/user.go +++ b/server/models/user.go @@ -1,5 +1,7 @@ package models +import "slices" + const ( UserRoleUser = "user" UserRoleAdmin = "admin" @@ -14,7 +16,7 @@ type User struct { Email string `json:"email" gorm:"unique"` Role string `json:"role" gorm:"default:user;not null;index:users_role_idx;type:varchar(8)"` - Teams []*Team `json:"teams" gorm:"many2many:team_members"` + Teams []*TeamMembers `json:"teams" gorm:"foreignKey:UserID"` Timestamps SoftDeletes @@ -28,3 +30,33 @@ type UserSession struct { Timestamps SoftDeletes } + +func (u *User) IsAdmin() bool { + return u.Role == UserRoleAdmin +} + +func (u *User) GetTeamRole(teamId *string) string { + if u.IsAdmin() { + return TeamRoleAdmin + } + if teamId == nil { + return "" + } + idx := slices.IndexFunc(u.Teams, func(tm *TeamMembers) bool { + return tm.TeamID == *teamId + }) + if idx == -1 { + return "" + } + return u.Teams[idx].Role +} + +func (u *User) IsInTeam(teamId *string) bool { + role := u.GetTeamRole(teamId) + return role != "" +} + +func (u *User) TeamCanWrite(teamId *string) bool { + role := u.GetTeamRole(teamId) + return role == TeamRoleAdmin || role == TeamRoleOwner +} diff --git a/server/utils/context.go b/server/utils/context.go index a1bbbcf..859021f 100644 --- a/server/utils/context.go +++ b/server/utils/context.go @@ -3,33 +3,17 @@ package utils import ( "github.com/gofiber/contrib/websocket" "github.com/gofiber/fiber/v2" - "rul.sh/vaulterm/models" + "rul.sh/vaulterm/middleware" ) -type UserContext struct { - *models.User - IsAdmin bool `json:"isAdmin"` -} - -func getUserData(user *models.User) *UserContext { - isAdmin := false - - if user.Role == models.UserRoleAdmin { - isAdmin = true - } - - return &UserContext{ - User: user, - IsAdmin: isAdmin, - } -} +type UserContext = middleware.AuthUser func GetUser(c *fiber.Ctx) *UserContext { - user := c.Locals("user").(*models.User) - return getUserData(user) + user, _ := c.Locals("user").(*UserContext) + return user } func GetUserWs(c *websocket.Conn) *UserContext { - user := c.Locals("user").(*models.User) - return getUserData(user) + user, _ := c.Locals("user").(*UserContext) + return user }