feat: team access control

This commit is contained in:
Khairul Hidayat 2024-11-12 17:17:10 +00:00
parent f5250d5361
commit 2d4c81e15d
31 changed files with 410 additions and 161 deletions

View File

@ -2,11 +2,17 @@ import { GestureHandlerRootView } from "react-native-gesture-handler";
import { Drawer } from "expo-router/drawer";
import React from "react";
import { useMedia } from "tamagui";
import DrawerContent from "@/components/containers/drawer";
import DrawerContent, {
DrawerNavigationOptions,
} from "@/components/containers/drawer";
import Icons from "@/components/ui/icons";
import { useUser } from "@/hooks/useUser";
import { useTeamId } from "@/stores/auth";
export default function Layout() {
const media = useMedia();
const teamId = useTeamId();
const user = useUser();
return (
<GestureHandlerRootView style={{ flex: 1 }}>
@ -29,12 +35,15 @@ export default function Layout() {
/>
<Drawer.Screen
name="keychains"
options={{
options={
{
title: "Keychains",
hidden: teamId && !user?.teamCanWrite(teamId),
drawerIcon: ({ size, color }) => (
<Icons name="key" size={size} color={color} />
),
}}
} as DrawerNavigationOptions
}
/>
<Drawer.Screen
name="terminal"

View File

@ -11,8 +11,8 @@ import { QueryClientProvider } from "@tanstack/react-query";
import { router, usePathname, useRootNavigationState } from "expo-router";
import { useAuthStore } from "@/stores/auth";
import { PortalProvider } from "tamagui";
import { queryClient } from "@/lib/api";
import { useServer } from "@/stores/app";
import queryClient from "@/lib/queryClient";
type Props = PropsWithChildren;

View File

@ -2,6 +2,7 @@ import React from "react";
import {
DrawerContentComponentProps,
DrawerContentScrollView,
DrawerNavigationOptions as NavProps,
} from "@react-navigation/drawer";
import { Button, View } from "tamagui";
import {
@ -13,6 +14,10 @@ import { Link } from "expo-router";
import ThemeSwitcher from "./theme-switcher";
import UserMenuButton from "./user-menu-button";
export type DrawerNavigationOptions = NavProps & {
hidden?: boolean | null;
};
const Drawer = (props: DrawerContentComponentProps) => {
return (
<>
@ -61,7 +66,12 @@ const DrawerItemList = ({
}
};
const { title, drawerLabel, drawerIcon } = descriptors[route.key].options;
const { title, drawerLabel, drawerIcon, hidden } = descriptors[route.key]
.options as DrawerNavigationOptions;
if (hidden) {
return null;
}
return (
<Link key={route.key} href={buildHref(route.name, route.params) as never}>

View File

@ -11,11 +11,13 @@ import {
} from "tamagui";
import MenuButton from "../ui/menu-button";
import Icons from "../ui/icons";
import { logout } from "@/stores/auth";
import { logout, setTeam, useTeamId } from "@/stores/auth";
import { useUser } from "@/hooks/useUser";
const UserMenuButton = () => {
const user = useUser();
const teamId = useTeamId();
const team = user?.teams?.find((t: any) => t.id === teamId);
return (
<MenuButton
@ -35,7 +37,7 @@ const UserMenuButton = () => {
<View flex={1} style={{ textAlign: "left" }}>
<Text numberOfLines={1}>{user?.name}</Text>
<Text numberOfLines={1} fontWeight="600" mt="$1.5">
Personal
{team ? `${team.icon} ${team.name}` : "Personal"}
</Text>
</View>
<Icons name="chevron-down" size={16} />
@ -61,6 +63,7 @@ const UserMenuButton = () => {
const TeamsMenu = () => {
const media = useMedia();
const user = useUser();
const teamId = useTeamId();
const teams = user?.teams || [];
return (
@ -73,20 +76,30 @@ const TeamsMenu = () => {
<ListItem
hoverTheme
pressTheme
onPress={() => console.log("logout")}
icon={<Icons name="account-group" size={16} />}
title="Teams"
iconAfter={<Icons name="chevron-right" size={16} />}
/>
}
>
{teamId != null && (
<MenuButton.Item
icon={<Icons name="account" size={16} />}
title="Personal"
onPress={() => setTeam(null)}
/>
)}
{teams.map((team: any) => (
<MenuButton.Item icon={<Text>{team.icon}</Text>} title={team.name} />
<MenuButton.Item
key={team.id}
icon={<Text>{team.icon}</Text>}
iconAfter={
teamId === team.id ? <Icons name="check" size={16} /> : undefined
}
title={team.name}
onPress={() => setTeam(team.id)}
/>
))}
{teams.length > 0 && <Separator width="100%" />}

View File

@ -6,5 +6,32 @@ export const useUser = () => {
queryKey: ["auth", "user"],
queryFn: authRepo.getUser,
});
return user;
if (!user) {
return null;
}
function getTeamRole(teamId?: string | null) {
if (!user.teams?.length) {
return false;
}
const team = user.teams.find((i: any) => i.id === teamId);
return team?.role;
}
function isInTeam(teamId?: string | null) {
return getTeamRole(teamId) != null;
}
function teamCanWrite(teamId?: string | null) {
const role = getTeamRole(teamId);
return ["admin", "owner"].includes(role);
}
return {
...user,
getTeamRole,
isInTeam,
teamCanWrite,
};
};

View File

@ -1,6 +1,5 @@
import { getCurrentServer } from "@/stores/app";
import authStore from "@/stores/auth";
import { QueryClient } from "@tanstack/react-query";
import { ofetch } from "ofetch";
const api = ofetch.create({
@ -13,9 +12,13 @@ const api = ofetch.create({
// set server url
config.options.baseURL = server.url;
const authToken = authStore.getState().token;
if (authToken) {
config.options.headers.set("Authorization", `Bearer ${authToken}`);
const { token, teamId } = authStore.getState();
if (token) {
config.options.headers.set("Authorization", `Bearer ${token}`);
}
if (teamId) {
config.options.headers.set("X-Team-Id", teamId);
}
},
onResponseError: (error) => {
@ -31,6 +34,4 @@ const api = ofetch.create({
},
});
export const queryClient = new QueryClient();
export default api;

View File

@ -0,0 +1,5 @@
import { QueryClient } from "@tanstack/react-query";
const queryClient = new QueryClient();
export default queryClient;

View File

@ -47,6 +47,7 @@ export default function LoginPage() {
marginHorizontal: "auto",
},
title: "Login",
headerTitle: "",
headerRight: () => (
<ThemeSwitcher bg="$colorTransparent" $gtSm={{ mr: "$3" }} />
),

View File

@ -8,6 +8,7 @@ import { useTermSession } from "@/stores/terminal-sessions";
import { hostFormModal } from "./form";
import GridView from "@/components/ui/grid-view";
import HostItem from "./host-item";
import { useHosts } from "../hooks/query";
type HostsListProps = {
allowEdit?: boolean;
@ -18,11 +19,7 @@ const HostList = ({ allowEdit = true }: HostsListProps) => {
const navigation = useNavigation();
const [search, setSearch] = useState("");
const hosts = useQuery({
queryKey: ["hosts"],
queryFn: () => api("/hosts"),
select: (i) => i.rows,
});
const hosts = useHosts();
const hostsList = useMemo(() => {
let items = hosts.data || [];

View File

@ -1,8 +1,19 @@
import { useMutation, useQuery } from "@tanstack/react-query";
import { FormSchema } from "../schema/form";
import api, { queryClient } from "@/lib/api";
import api from "@/lib/api";
import { useMemo } from "react";
import { useKeychains } from "@/pages/keychains/hooks/query";
import queryClient from "@/lib/queryClient";
import { useTeamId } from "@/stores/auth";
export const useHosts = () => {
const teamId = useTeamId();
return useQuery({
queryKey: ["hosts", teamId],
queryFn: () => api("/hosts", { params: { teamId } }),
select: (i) => i.rows,
});
};
export const useKeychainsOptions = () => {
const keys = useKeychains();
@ -20,8 +31,11 @@ export const useKeychainsOptions = () => {
};
export const useSaveHost = () => {
const teamId = useTeamId();
return useMutation({
mutationFn: async (body: FormSchema) => {
mutationFn: async (payload: FormSchema) => {
const body = { teamId, ...payload };
return body.id
? api(`/hosts/${body.id}`, { method: "PUT", body })
: api(`/hosts`, { method: "POST", body });

View File

@ -1,8 +1,13 @@
import api, { queryClient } from "@/lib/api";
import api from "@/lib/api";
import { useMutation, useQuery } from "@tanstack/react-query";
import { FormSchema } from "../schema/form";
import queryClient from "@/lib/queryClient";
import { useTeamId } from "@/stores/auth";
export const useKeychains = (params?: any) => {
const teamId = useTeamId();
const query = { teamId, ...params };
export const useKeychains = (query?: any) => {
return useQuery({
queryKey: ["keychains", query],
queryFn: () => api("/keychains", { query }),
@ -11,8 +16,11 @@ export const useKeychains = (query?: any) => {
};
export const useSaveKeychain = () => {
const teamId = useTeamId();
return useMutation({
mutationFn: async (body: FormSchema) => {
mutationFn: async (payload: FormSchema) => {
const body = { teamId, ...payload };
return body.id
? api(`/keychains/${body.id}`, { method: "PUT", body })
: api(`/keychains`, { method: "POST", body });

View File

@ -2,15 +2,18 @@ import { createStore, useStore } from "zustand";
import { persist, createJSONStorage } from "zustand/middleware";
import AsyncStorage from "@react-native-async-storage/async-storage";
import termSessionStore from "./terminal-sessions";
import queryClient from "@/lib/queryClient";
type AuthStore = {
token?: string | null;
token: string | null;
teamId: string | null;
};
const authStore = createStore(
persist<AuthStore>(
() => ({
token: null,
teamId: null,
}),
{
name: "vaulterm:auth",
@ -24,9 +27,18 @@ export const useAuthStore = () => {
return { ...state, isLoggedIn: state.token != null };
};
export const setTeam = (teamId: string | null) => {
authStore.setState({ teamId });
queryClient.invalidateQueries();
};
export const logout = () => {
authStore.setState({ token: null });
authStore.setState({ token: null, teamId: null });
termSessionStore.setState({ sessions: [], curSession: 0 });
};
export const useTeamId = () => {
return useStore(authStore, (i) => i.teamId);
};
export default authStore;

View File

@ -54,7 +54,21 @@ func login(c *fiber.Ctx) error {
func getUser(c *fiber.Ctx) error {
user := utils.GetUser(c)
return c.JSON(user)
teams := []TeamWithRole{}
for _, item := range user.Teams {
teams = append(teams, TeamWithRole{
ID: item.TeamID,
Name: item.Team.Name,
Icon: item.Team.Icon,
Role: item.Role,
})
}
return c.JSON(&GetUserResult{
AuthUser: *user,
Teams: teams,
})
}
func logout(c *fiber.Ctx) error {

View File

@ -1,6 +1,20 @@
package auth
import "rul.sh/vaulterm/middleware"
type LoginSchema struct {
Username string `json:"username"`
Password string `json:"password"`
}
type TeamWithRole struct {
ID string `json:"id"`
Name string `json:"name"`
Icon string `json:"icon"`
Role string `json:"role"`
}
type GetUserResult struct {
middleware.AuthUser
Teams []TeamWithRole `json:"teams"`
}

View File

@ -20,8 +20,14 @@ func NewRepository(r *Hosts) *Hosts {
return r
}
func (r *Hosts) GetAll() ([]*models.Host, error) {
query := r.ACL(r.db.Order("id DESC"))
func (r *Hosts) GetAll(opt GetAllOpt) ([]*models.Host, error) {
query := r.db.Order("id DESC")
if opt.TeamID != "" {
query = query.Where("hosts.team_id = ?", opt.TeamID)
} else {
query = query.Where("hosts.owner_id = ? AND hosts.team_id IS NULL", r.User.ID)
}
var rows []*models.Host
ret := query.Find(&rows)
@ -30,10 +36,8 @@ func (r *Hosts) GetAll() ([]*models.Host, error) {
}
func (r *Hosts) Get(id string) (*models.HostDecrypted, error) {
query := r.ACL(r.db)
var host models.Host
ret := query.Joins("Key").Joins("AltKey").Where("hosts.id = ?", id).First(&host)
ret := r.db.Joins("Key").Joins("AltKey").Where("hosts.id = ?", id).First(&host)
if ret.Error != nil {
return nil, ret.Error
}
@ -48,13 +52,12 @@ func (r *Hosts) Get(id string) (*models.HostDecrypted, error) {
func (r *Hosts) Exists(id string) (bool, error) {
var count int64
ret := r.ACL(r.db.Model(&models.Host{}).Where("id = ?", id)).Count(&count)
ret := r.db.Model(&models.Host{}).Where("id = ?", id).Count(&count)
return count > 0, ret.Error
}
func (r *Hosts) Delete(id string) error {
query := r.ACL(r.db)
return query.Delete(&models.Host{Model: models.Model{ID: id}}).Error
return r.db.Delete(&models.Host{Model: models.Model{ID: id}}).Error
}
func (r *Hosts) Create(item *models.Host) error {
@ -62,15 +65,5 @@ func (r *Hosts) Create(item *models.Host) error {
}
func (r *Hosts) Update(id string, item *models.Host) error {
query := r.ACL(r.db.Where("id = ?", id))
return query.Updates(item).Error
}
func (r *Hosts) ACL(query *gorm.DB) *gorm.DB {
if r.User.IsAdmin {
return query
}
return query.Where("hosts.owner_id = ?", r.User.ID)
return r.db.Where("id = ?", id).Updates(item).Error
}

View File

@ -1,6 +1,7 @@
package hosts
import (
"errors"
"fmt"
"net/http"
@ -19,10 +20,15 @@ func Router(app fiber.Router) {
}
func getAll(c *fiber.Ctx) error {
teamId := c.Query("teamId")
user := utils.GetUser(c)
repo := NewRepository(&Hosts{User: user})
rows, err := repo.GetAll()
if teamId != "" && !user.IsInTeam(&teamId) {
return utils.ResponseError(c, errors.New("no access"), 403)
}
rows, err := repo.GetAll(GetAllOpt{TeamID: teamId})
if err != nil {
return utils.ResponseError(c, err, 500)
}
@ -41,8 +47,13 @@ func create(c *fiber.Ctx) error {
user := utils.GetUser(c)
repo := NewRepository(&Hosts{User: user})
if body.TeamID != nil && !user.TeamCanWrite(body.TeamID) {
return utils.ResponseError(c, errors.New("no access"), 403)
}
item := &models.Host{
OwnerID: user.ID,
OwnerID: &user.ID,
TeamID: body.TeamID,
Type: body.Type,
Label: body.Label,
Host: body.Host,
@ -76,13 +87,17 @@ func update(c *fiber.Ctx) error {
repo := NewRepository(&Hosts{User: user})
id := c.Params("id")
exist, _ := repo.Exists(id)
if !exist {
return utils.ResponseError(c, fmt.Errorf("host %s not found", id), 404)
data, _ := repo.Get(id)
if data == nil {
return utils.ResponseError(c, errors.New("host not found"), 404)
}
if !data.CanWrite(&user.User) || !user.TeamCanWrite(body.TeamID) {
return utils.ResponseError(c, errors.New("no access"), 403)
}
item := &models.Host{
Model: models.Model{ID: id},
TeamID: body.TeamID,
Type: body.Type,
Label: body.Label,
Host: body.Host,
@ -111,9 +126,12 @@ func delete(c *fiber.Ctx) error {
repo := NewRepository(&Hosts{User: user})
id := c.Params("id")
exist, _ := repo.Exists(id)
if !exist {
return utils.ResponseError(c, fmt.Errorf("host %s not found", id), 404)
host, _ := repo.Get(id)
if host == nil {
return utils.ResponseError(c, errors.New("host not found"), 404)
}
if !host.CanWrite(&user.User) {
return utils.ResponseError(c, errors.New("no access"), 403)
}
if err := repo.Delete(id); err != nil {

View File

@ -9,7 +9,12 @@ type CreateHostSchema struct {
Port int `json:"port"`
Metadata datatypes.JSONMap `json:"metadata"`
TeamID *string `json:"teamId"`
ParentID *string `json:"parentId"`
KeyID *string `json:"keyId"`
AltKeyID *string `json:"altKeyId"`
}
type GetAllOpt struct {
TeamID string
}

View File

@ -20,10 +20,16 @@ func NewRepository(r *Keychains) *Keychains {
return r
}
func (r *Keychains) GetAll() ([]*models.Keychain, error) {
var rows []*models.Keychain
query := r.ACL(r.db.Order("created_at DESC"))
func (r *Keychains) GetAll(opt GetAllOpt) ([]*models.Keychain, error) {
query := r.db.Order("created_at DESC")
if opt.TeamID != "" {
query = query.Where("keychains.team_id = ?", opt.TeamID)
} else {
query = query.Where("keychains.owner_id = ? AND keychains.team_id IS NULL", r.User.ID)
}
var rows []*models.Keychain
ret := query.Find(&rows)
return rows, ret.Error
}
@ -34,9 +40,7 @@ func (r *Keychains) Create(item *models.Keychain) error {
func (r *Keychains) Get(id string) (*models.Keychain, error) {
var keychain models.Keychain
query := r.ACL(r.db.Where("id = ?", id))
if err := query.First(&keychain).Error; err != nil {
if err := r.db.Where("id = ?", id).First(&keychain).Error; err != nil {
return nil, err
}
@ -45,8 +49,7 @@ func (r *Keychains) Get(id string) (*models.Keychain, error) {
func (r *Keychains) Exists(id string) (bool, error) {
var count int64
query := r.ACL(r.db.Model(&models.Keychain{}).Where("id = ?", id))
ret := query.Count(&count)
ret := r.db.Model(&models.Keychain{}).Where("id = ?", id).Count(&count)
return count > 0, ret.Error
}
@ -70,14 +73,5 @@ func (r *Keychains) GetDecrypted(id string) (*KeychainDecrypted, error) {
}
func (r *Keychains) Update(id string, item *models.Keychain) error {
query := r.ACL(r.db.Where("id = ?", id))
return query.Updates(item).Error
}
func (r *Keychains) ACL(query *gorm.DB) *gorm.DB {
if r.User.IsAdmin {
return query
}
return query.Where("keychains.owner_id = ?", r.User.ID)
return r.db.Where("id = ?", id).Updates(item).Error
}

View File

@ -1,7 +1,7 @@
package keychains
import (
"fmt"
"errors"
"net/http"
"github.com/gofiber/fiber/v2"
@ -23,17 +23,22 @@ type GetAllResult struct {
}
func getAll(c *fiber.Ctx) error {
teamId := c.Query("teamId")
withData := c.Query("withData")
user := utils.GetUser(c)
repo := NewRepository(&Keychains{User: user})
rows, err := repo.GetAll()
if teamId != "" && !user.IsInTeam(&teamId) {
return utils.ResponseError(c, errors.New("no access"), 403)
}
rows, err := repo.GetAll(GetAllOpt{TeamID: teamId})
if err != nil {
return utils.ResponseError(c, err, 500)
}
if withData != "true" {
if withData != "true" || (teamId != "" && !user.TeamCanWrite(&teamId)) {
return c.JSON(fiber.Map{"rows": rows})
}
@ -67,8 +72,13 @@ func create(c *fiber.Ctx) error {
user := utils.GetUser(c)
repo := NewRepository(&Keychains{User: user})
if body.TeamID != nil && !user.TeamCanWrite(body.TeamID) {
return utils.ResponseError(c, errors.New("no access"), 403)
}
item := &models.Keychain{
OwnerID: user.ID,
OwnerID: &user.ID,
TeamID: body.TeamID,
Type: body.Type,
Label: body.Label,
}
@ -94,13 +104,16 @@ func update(c *fiber.Ctx) error {
repo := NewRepository(&Keychains{User: user})
id := c.Params("id")
exist, _ := repo.Exists(id)
if !exist {
return utils.ResponseError(c, fmt.Errorf("key %s not found", id), 404)
data, _ := repo.Get(id)
if data == nil {
return utils.ResponseError(c, errors.New("key not found"), 404)
}
if !data.CanWrite(&user.User) || !user.TeamCanWrite(body.TeamID) {
return utils.ResponseError(c, errors.New("no access"), 403)
}
item := &models.Keychain{
TeamID: body.TeamID,
Type: body.Type,
Label: body.Label,
}

View File

@ -1,7 +1,12 @@
package keychains
type CreateKeychainSchema struct {
TeamID *string `json:"teamId"`
Type string `json:"type"`
Label string `json:"label"`
Data interface{} `json:"data"`
}
type GetAllOpt struct {
TeamID string
}

View File

@ -0,0 +1,50 @@
package teams
import (
"gorm.io/gorm"
"rul.sh/vaulterm/db"
"rul.sh/vaulterm/models"
"rul.sh/vaulterm/utils"
)
type Teams struct {
db *gorm.DB
User *utils.UserContext
}
func NewRepository(r *Teams) *Teams {
if r == nil {
r = &Teams{}
}
r.db = db.Get()
return r
}
func (r *Teams) GetAll() ([]*models.Team, error) {
var rows []*models.Team
ret := r.db.Order("created_at DESC").Find(&rows)
return rows, ret.Error
}
func (r *Teams) Create(data *models.Team) error {
return r.db.Create(data).Error
}
func (r *Teams) Get(id string) (*models.Team, error) {
var data models.Team
if err := r.db.Where("id = ?", id).First(&data).Error; err != nil {
return nil, err
}
return &data, nil
}
func (r *Teams) Exists(id string) (bool, error) {
var count int64
ret := r.db.Model(&models.Team{}).Where("id = ?", id).Count(&count)
return count > 0, ret.Error
}
func (r *Teams) Update(id string, item *models.Team) error {
return r.db.Where("id = ?", id).Updates(item).Error
}

View File

@ -17,8 +17,8 @@ func HandleTerm(c *websocket.Conn) {
hostRepo := hosts.NewRepository(&hosts.Hosts{User: user})
data, err := hostRepo.Get(hostId)
if data == nil {
log.Printf("Cannot find host! Error: %s\n", err.Error())
if data == nil || !data.HasAccess(&user.User) {
log.Printf("Cannot find host! %v\n", err)
c.WriteMessage(websocket.TextMessage, []byte("Host not found"))
return
}

View File

@ -45,7 +45,6 @@ func Init() {
// Migrate the schema
db.AutoMigrate(Models...)
InitModels(db)
runSeeders(db)
}

View File

@ -1,9 +1,6 @@
package db
import (
"log"
"gorm.io/gorm"
"rul.sh/vaulterm/models"
)
@ -15,9 +12,3 @@ var Models = []interface{}{
&models.Team{},
&models.TeamMembers{},
}
func InitModels(db *gorm.DB) {
if err := db.SetupJoinTable(&models.Team{}, "Members", &models.TeamMembers{}); err != nil {
log.Fatal(err)
}
}

View File

@ -66,9 +66,9 @@ func seedUsers(tx *gorm.DB) error {
}
teamMembers := []models.TeamMembers{
{TeamID: teams[0].ID, UserID: userList[0].ID, Role: "owner"},
{TeamID: teams[0].ID, UserID: userList[1].ID, Role: "admin"},
{TeamID: teams[0].ID, UserID: userList[2].ID, Role: "user"},
{TeamID: teams[0].ID, UserID: userList[0].ID, Role: models.TeamRoleOwner},
{TeamID: teams[0].ID, UserID: userList[1].ID, Role: models.TeamRoleAdmin},
{TeamID: teams[0].ID, UserID: userList[2].ID, Role: models.TeamRoleMember},
}
if res := tx.Create(&teamMembers); res.Error != nil {

View File

@ -4,7 +4,6 @@ import (
"strings"
"github.com/gofiber/fiber/v2"
"gorm.io/gorm"
"rul.sh/vaulterm/db"
"rul.sh/vaulterm/models"
)
@ -23,25 +22,35 @@ func Auth(c *fiber.Ctx) error {
session, _ := GetUserSession(sessionId)
if session != nil && session.User.ID != "" {
c.Locals("user", &session.User)
if session != nil && session.ID != "" {
c.Locals("user", session)
c.Locals("sessionId", sessionId)
}
return c.Next()
}
func GetUserSession(sessionId string) (*models.UserSession, error) {
var session models.UserSession
type AuthUser struct {
models.User
SessionID string `json:"sessionId" gorm:"column:session_id"`
}
func GetUserSession(sessionId string) (*AuthUser, error) {
var session AuthUser
res := db.Get().
Joins("User").
Preload("User.Teams", func(db *gorm.DB) *gorm.DB {
return db.Select("id", "name", "icon")
}).
Model(&models.User{}).
Joins("JOIN user_sessions ON user_sessions.user_id = users.id").
Preload("Teams.Team").
Select("users.*, user_sessions.id AS session_id").
Where("user_sessions.id = ?", sessionId).
First(&session)
return &session, res.Error
if res.Error != nil || session.User.ID == "" {
return nil, res.Error
}
return &session, nil
}
func Protected() func(c *fiber.Ctx) error {

View File

@ -1,6 +1,8 @@
package models
import "gorm.io/datatypes"
import (
"gorm.io/datatypes"
)
const (
HostTypeSSH = "ssh"
@ -14,8 +16,10 @@ const (
type Host struct {
Model
OwnerID string `json:"userId" gorm:"index:hosts_owner_id_idx;type:varchar(26)"`
Owner User `json:"user" gorm:"foreignKey:OwnerID"`
OwnerID *string `json:"userId" gorm:"type:varchar(26)"`
Owner *User `json:"user" gorm:"foreignKey:OwnerID"`
TeamID *string `json:"teamId" gorm:"type:varchar(26)"`
Team *Team `json:"team" gorm:"foreignKey:TeamID"`
Type string `json:"type" gorm:"not null;index:hosts_type_idx;type:varchar(16)"`
Label string `json:"label"`
@ -24,11 +28,11 @@ type Host struct {
OS string `json:"os" gorm:"type:varchar(32)"`
Metadata datatypes.JSONMap `json:"metadata"`
ParentID *string `json:"parentId" gorm:"index:hosts_parent_id_idx;type:varchar(26)"`
ParentID *string `json:"parentId" gorm:"type:varchar(26)"`
Parent *Host `json:"parent" gorm:"foreignKey:ParentID"`
KeyID *string `json:"keyId" gorm:"index:hosts_key_id_idx"`
KeyID *string `json:"keyId" gorm:"type:varchar(26)"`
Key Keychain `json:"key" gorm:"foreignKey:KeyID"`
AltKeyID *string `json:"altKeyId" gorm:"index:hosts_altkey_id_idx"`
AltKeyID *string `json:"altKeyId" gorm:"type:varchar(26)"`
AltKey Keychain `json:"altKey" gorm:"foreignKey:AltKeyID"`
Timestamps
@ -58,13 +62,17 @@ func (h *Host) DecryptKeys() (*HostDecrypted, error) {
return res, nil
}
type HostHasAccessOptions struct {
UserID string
}
func (h *Host) HasAccess(o HostHasAccessOptions) bool {
if o.UserID == h.OwnerID {
func (h *Host) HasAccess(user *User) bool {
if user.IsAdmin() {
return true
}
return false
return *h.OwnerID == user.ID || user.IsInTeam(h.TeamID)
}
func (h *Host) CanWrite(user *User) bool {
if user.IsAdmin() {
return true
}
teamRole := user.GetTeamRole(h.TeamID)
return *h.OwnerID == user.ID || teamRole == TeamRoleOwner || teamRole == TeamRoleAdmin
}

View File

@ -16,8 +16,10 @@ const (
type Keychain struct {
Model
OwnerID string `json:"userId" gorm:"index:hosts_owner_id_idx;type:varchar(26)"`
Owner User `json:"user" gorm:"foreignKey:OwnerID"`
OwnerID *string `json:"userId" gorm:"type:varchar(26)"`
Owner *User `json:"user" gorm:"foreignKey:OwnerID"`
TeamID *string `json:"teamId" gorm:"type:varchar(26)"`
Team *Team `json:"team" gorm:"foreignKey:TeamID"`
Label string `json:"label"`
Type string `json:"type" gorm:"not null;index:keychains_type_idx;type:varchar(12)"`
@ -55,3 +57,18 @@ func (k *Keychain) DecryptData(data interface{}) error {
return nil
}
func (k *Keychain) HasAccess(user *User) bool {
if user.IsAdmin() {
return true
}
return *k.OwnerID == user.ID || user.IsInTeam(k.TeamID)
}
func (k *Keychain) CanWrite(user *User) bool {
if user.IsAdmin() {
return true
}
teamRole := user.GetTeamRole(k.TeamID)
return *k.OwnerID == user.ID || teamRole == TeamRoleOwner || teamRole == TeamRoleAdmin
}

View File

@ -2,12 +2,18 @@ package models
import "time"
const (
TeamRoleOwner = "owner"
TeamRoleAdmin = "admin"
TeamRoleMember = "member"
)
type Team struct {
Model
Name string `json:"name" gorm:"type:varchar(32)"`
Icon string `json:"icon" gorm:"type:varchar(2)"`
Members []*User `json:"members" gorm:"many2many:team_members"`
Members []*TeamMembers `json:"members" gorm:"foreignKey:TeamID"`
Timestamps
SoftDeletes

View File

@ -1,5 +1,7 @@
package models
import "slices"
const (
UserRoleUser = "user"
UserRoleAdmin = "admin"
@ -14,7 +16,7 @@ type User struct {
Email string `json:"email" gorm:"unique"`
Role string `json:"role" gorm:"default:user;not null;index:users_role_idx;type:varchar(8)"`
Teams []*Team `json:"teams" gorm:"many2many:team_members"`
Teams []*TeamMembers `json:"teams" gorm:"foreignKey:UserID"`
Timestamps
SoftDeletes
@ -28,3 +30,33 @@ type UserSession struct {
Timestamps
SoftDeletes
}
func (u *User) IsAdmin() bool {
return u.Role == UserRoleAdmin
}
func (u *User) GetTeamRole(teamId *string) string {
if u.IsAdmin() {
return TeamRoleAdmin
}
if teamId == nil {
return ""
}
idx := slices.IndexFunc(u.Teams, func(tm *TeamMembers) bool {
return tm.TeamID == *teamId
})
if idx == -1 {
return ""
}
return u.Teams[idx].Role
}
func (u *User) IsInTeam(teamId *string) bool {
role := u.GetTeamRole(teamId)
return role != ""
}
func (u *User) TeamCanWrite(teamId *string) bool {
role := u.GetTeamRole(teamId)
return role == TeamRoleAdmin || role == TeamRoleOwner
}

View File

@ -3,33 +3,17 @@ package utils
import (
"github.com/gofiber/contrib/websocket"
"github.com/gofiber/fiber/v2"
"rul.sh/vaulterm/models"
"rul.sh/vaulterm/middleware"
)
type UserContext struct {
*models.User
IsAdmin bool `json:"isAdmin"`
}
func getUserData(user *models.User) *UserContext {
isAdmin := false
if user.Role == models.UserRoleAdmin {
isAdmin = true
}
return &UserContext{
User: user,
IsAdmin: isAdmin,
}
}
type UserContext = middleware.AuthUser
func GetUser(c *fiber.Ctx) *UserContext {
user := c.Locals("user").(*models.User)
return getUserData(user)
user, _ := c.Locals("user").(*UserContext)
return user
}
func GetUserWs(c *websocket.Conn) *UserContext {
user := c.Locals("user").(*models.User)
return getUserData(user)
user, _ := c.Locals("user").(*UserContext)
return user
}