feat: team access control

This commit is contained in:
Khairul Hidayat 2024-11-12 17:17:10 +00:00
parent f5250d5361
commit 2d4c81e15d
31 changed files with 410 additions and 161 deletions

View File

@ -2,11 +2,17 @@ import { GestureHandlerRootView } from "react-native-gesture-handler";
import { Drawer } from "expo-router/drawer"; import { Drawer } from "expo-router/drawer";
import React from "react"; import React from "react";
import { useMedia } from "tamagui"; import { useMedia } from "tamagui";
import DrawerContent from "@/components/containers/drawer"; import DrawerContent, {
DrawerNavigationOptions,
} from "@/components/containers/drawer";
import Icons from "@/components/ui/icons"; import Icons from "@/components/ui/icons";
import { useUser } from "@/hooks/useUser";
import { useTeamId } from "@/stores/auth";
export default function Layout() { export default function Layout() {
const media = useMedia(); const media = useMedia();
const teamId = useTeamId();
const user = useUser();
return ( return (
<GestureHandlerRootView style={{ flex: 1 }}> <GestureHandlerRootView style={{ flex: 1 }}>
@ -29,12 +35,15 @@ export default function Layout() {
/> />
<Drawer.Screen <Drawer.Screen
name="keychains" name="keychains"
options={{ options={
title: "Keychains", {
drawerIcon: ({ size, color }) => ( title: "Keychains",
<Icons name="key" size={size} color={color} /> hidden: teamId && !user?.teamCanWrite(teamId),
), drawerIcon: ({ size, color }) => (
}} <Icons name="key" size={size} color={color} />
),
} as DrawerNavigationOptions
}
/> />
<Drawer.Screen <Drawer.Screen
name="terminal" name="terminal"

View File

@ -11,8 +11,8 @@ import { QueryClientProvider } from "@tanstack/react-query";
import { router, usePathname, useRootNavigationState } from "expo-router"; import { router, usePathname, useRootNavigationState } from "expo-router";
import { useAuthStore } from "@/stores/auth"; import { useAuthStore } from "@/stores/auth";
import { PortalProvider } from "tamagui"; import { PortalProvider } from "tamagui";
import { queryClient } from "@/lib/api";
import { useServer } from "@/stores/app"; import { useServer } from "@/stores/app";
import queryClient from "@/lib/queryClient";
type Props = PropsWithChildren; type Props = PropsWithChildren;

View File

@ -2,6 +2,7 @@ import React from "react";
import { import {
DrawerContentComponentProps, DrawerContentComponentProps,
DrawerContentScrollView, DrawerContentScrollView,
DrawerNavigationOptions as NavProps,
} from "@react-navigation/drawer"; } from "@react-navigation/drawer";
import { Button, View } from "tamagui"; import { Button, View } from "tamagui";
import { import {
@ -13,6 +14,10 @@ import { Link } from "expo-router";
import ThemeSwitcher from "./theme-switcher"; import ThemeSwitcher from "./theme-switcher";
import UserMenuButton from "./user-menu-button"; import UserMenuButton from "./user-menu-button";
export type DrawerNavigationOptions = NavProps & {
hidden?: boolean | null;
};
const Drawer = (props: DrawerContentComponentProps) => { const Drawer = (props: DrawerContentComponentProps) => {
return ( return (
<> <>
@ -61,7 +66,12 @@ const DrawerItemList = ({
} }
}; };
const { title, drawerLabel, drawerIcon } = descriptors[route.key].options; const { title, drawerLabel, drawerIcon, hidden } = descriptors[route.key]
.options as DrawerNavigationOptions;
if (hidden) {
return null;
}
return ( return (
<Link key={route.key} href={buildHref(route.name, route.params) as never}> <Link key={route.key} href={buildHref(route.name, route.params) as never}>

View File

@ -11,11 +11,13 @@ import {
} from "tamagui"; } from "tamagui";
import MenuButton from "../ui/menu-button"; import MenuButton from "../ui/menu-button";
import Icons from "../ui/icons"; import Icons from "../ui/icons";
import { logout } from "@/stores/auth"; import { logout, setTeam, useTeamId } from "@/stores/auth";
import { useUser } from "@/hooks/useUser"; import { useUser } from "@/hooks/useUser";
const UserMenuButton = () => { const UserMenuButton = () => {
const user = useUser(); const user = useUser();
const teamId = useTeamId();
const team = user?.teams?.find((t: any) => t.id === teamId);
return ( return (
<MenuButton <MenuButton
@ -35,7 +37,7 @@ const UserMenuButton = () => {
<View flex={1} style={{ textAlign: "left" }}> <View flex={1} style={{ textAlign: "left" }}>
<Text numberOfLines={1}>{user?.name}</Text> <Text numberOfLines={1}>{user?.name}</Text>
<Text numberOfLines={1} fontWeight="600" mt="$1.5"> <Text numberOfLines={1} fontWeight="600" mt="$1.5">
Personal {team ? `${team.icon} ${team.name}` : "Personal"}
</Text> </Text>
</View> </View>
<Icons name="chevron-down" size={16} /> <Icons name="chevron-down" size={16} />
@ -61,6 +63,7 @@ const UserMenuButton = () => {
const TeamsMenu = () => { const TeamsMenu = () => {
const media = useMedia(); const media = useMedia();
const user = useUser(); const user = useUser();
const teamId = useTeamId();
const teams = user?.teams || []; const teams = user?.teams || [];
return ( return (
@ -73,20 +76,30 @@ const TeamsMenu = () => {
<ListItem <ListItem
hoverTheme hoverTheme
pressTheme pressTheme
onPress={() => console.log("logout")}
icon={<Icons name="account-group" size={16} />} icon={<Icons name="account-group" size={16} />}
title="Teams" title="Teams"
iconAfter={<Icons name="chevron-right" size={16} />} iconAfter={<Icons name="chevron-right" size={16} />}
/> />
} }
> >
<MenuButton.Item {teamId != null && (
icon={<Icons name="account" size={16} />} <MenuButton.Item
title="Personal" icon={<Icons name="account" size={16} />}
/> title="Personal"
onPress={() => setTeam(null)}
/>
)}
{teams.map((team: any) => ( {teams.map((team: any) => (
<MenuButton.Item icon={<Text>{team.icon}</Text>} title={team.name} /> <MenuButton.Item
key={team.id}
icon={<Text>{team.icon}</Text>}
iconAfter={
teamId === team.id ? <Icons name="check" size={16} /> : undefined
}
title={team.name}
onPress={() => setTeam(team.id)}
/>
))} ))}
{teams.length > 0 && <Separator width="100%" />} {teams.length > 0 && <Separator width="100%" />}

View File

@ -6,5 +6,32 @@ export const useUser = () => {
queryKey: ["auth", "user"], queryKey: ["auth", "user"],
queryFn: authRepo.getUser, queryFn: authRepo.getUser,
}); });
return user;
if (!user) {
return null;
}
function getTeamRole(teamId?: string | null) {
if (!user.teams?.length) {
return false;
}
const team = user.teams.find((i: any) => i.id === teamId);
return team?.role;
}
function isInTeam(teamId?: string | null) {
return getTeamRole(teamId) != null;
}
function teamCanWrite(teamId?: string | null) {
const role = getTeamRole(teamId);
return ["admin", "owner"].includes(role);
}
return {
...user,
getTeamRole,
isInTeam,
teamCanWrite,
};
}; };

View File

@ -1,6 +1,5 @@
import { getCurrentServer } from "@/stores/app"; import { getCurrentServer } from "@/stores/app";
import authStore from "@/stores/auth"; import authStore from "@/stores/auth";
import { QueryClient } from "@tanstack/react-query";
import { ofetch } from "ofetch"; import { ofetch } from "ofetch";
const api = ofetch.create({ const api = ofetch.create({
@ -13,9 +12,13 @@ const api = ofetch.create({
// set server url // set server url
config.options.baseURL = server.url; config.options.baseURL = server.url;
const authToken = authStore.getState().token; const { token, teamId } = authStore.getState();
if (authToken) {
config.options.headers.set("Authorization", `Bearer ${authToken}`); if (token) {
config.options.headers.set("Authorization", `Bearer ${token}`);
}
if (teamId) {
config.options.headers.set("X-Team-Id", teamId);
} }
}, },
onResponseError: (error) => { onResponseError: (error) => {
@ -31,6 +34,4 @@ const api = ofetch.create({
}, },
}); });
export const queryClient = new QueryClient();
export default api; export default api;

View File

@ -0,0 +1,5 @@
import { QueryClient } from "@tanstack/react-query";
const queryClient = new QueryClient();
export default queryClient;

View File

@ -47,6 +47,7 @@ export default function LoginPage() {
marginHorizontal: "auto", marginHorizontal: "auto",
}, },
title: "Login", title: "Login",
headerTitle: "",
headerRight: () => ( headerRight: () => (
<ThemeSwitcher bg="$colorTransparent" $gtSm={{ mr: "$3" }} /> <ThemeSwitcher bg="$colorTransparent" $gtSm={{ mr: "$3" }} />
), ),

View File

@ -8,6 +8,7 @@ import { useTermSession } from "@/stores/terminal-sessions";
import { hostFormModal } from "./form"; import { hostFormModal } from "./form";
import GridView from "@/components/ui/grid-view"; import GridView from "@/components/ui/grid-view";
import HostItem from "./host-item"; import HostItem from "./host-item";
import { useHosts } from "../hooks/query";
type HostsListProps = { type HostsListProps = {
allowEdit?: boolean; allowEdit?: boolean;
@ -18,11 +19,7 @@ const HostList = ({ allowEdit = true }: HostsListProps) => {
const navigation = useNavigation(); const navigation = useNavigation();
const [search, setSearch] = useState(""); const [search, setSearch] = useState("");
const hosts = useQuery({ const hosts = useHosts();
queryKey: ["hosts"],
queryFn: () => api("/hosts"),
select: (i) => i.rows,
});
const hostsList = useMemo(() => { const hostsList = useMemo(() => {
let items = hosts.data || []; let items = hosts.data || [];

View File

@ -1,8 +1,19 @@
import { useMutation, useQuery } from "@tanstack/react-query"; import { useMutation, useQuery } from "@tanstack/react-query";
import { FormSchema } from "../schema/form"; import { FormSchema } from "../schema/form";
import api, { queryClient } from "@/lib/api"; import api from "@/lib/api";
import { useMemo } from "react"; import { useMemo } from "react";
import { useKeychains } from "@/pages/keychains/hooks/query"; import { useKeychains } from "@/pages/keychains/hooks/query";
import queryClient from "@/lib/queryClient";
import { useTeamId } from "@/stores/auth";
export const useHosts = () => {
const teamId = useTeamId();
return useQuery({
queryKey: ["hosts", teamId],
queryFn: () => api("/hosts", { params: { teamId } }),
select: (i) => i.rows,
});
};
export const useKeychainsOptions = () => { export const useKeychainsOptions = () => {
const keys = useKeychains(); const keys = useKeychains();
@ -20,8 +31,11 @@ export const useKeychainsOptions = () => {
}; };
export const useSaveHost = () => { export const useSaveHost = () => {
const teamId = useTeamId();
return useMutation({ return useMutation({
mutationFn: async (body: FormSchema) => { mutationFn: async (payload: FormSchema) => {
const body = { teamId, ...payload };
return body.id return body.id
? api(`/hosts/${body.id}`, { method: "PUT", body }) ? api(`/hosts/${body.id}`, { method: "PUT", body })
: api(`/hosts`, { method: "POST", body }); : api(`/hosts`, { method: "POST", body });

View File

@ -1,8 +1,13 @@
import api, { queryClient } from "@/lib/api"; import api from "@/lib/api";
import { useMutation, useQuery } from "@tanstack/react-query"; import { useMutation, useQuery } from "@tanstack/react-query";
import { FormSchema } from "../schema/form"; import { FormSchema } from "../schema/form";
import queryClient from "@/lib/queryClient";
import { useTeamId } from "@/stores/auth";
export const useKeychains = (params?: any) => {
const teamId = useTeamId();
const query = { teamId, ...params };
export const useKeychains = (query?: any) => {
return useQuery({ return useQuery({
queryKey: ["keychains", query], queryKey: ["keychains", query],
queryFn: () => api("/keychains", { query }), queryFn: () => api("/keychains", { query }),
@ -11,8 +16,11 @@ export const useKeychains = (query?: any) => {
}; };
export const useSaveKeychain = () => { export const useSaveKeychain = () => {
const teamId = useTeamId();
return useMutation({ return useMutation({
mutationFn: async (body: FormSchema) => { mutationFn: async (payload: FormSchema) => {
const body = { teamId, ...payload };
return body.id return body.id
? api(`/keychains/${body.id}`, { method: "PUT", body }) ? api(`/keychains/${body.id}`, { method: "PUT", body })
: api(`/keychains`, { method: "POST", body }); : api(`/keychains`, { method: "POST", body });

View File

@ -2,15 +2,18 @@ import { createStore, useStore } from "zustand";
import { persist, createJSONStorage } from "zustand/middleware"; import { persist, createJSONStorage } from "zustand/middleware";
import AsyncStorage from "@react-native-async-storage/async-storage"; import AsyncStorage from "@react-native-async-storage/async-storage";
import termSessionStore from "./terminal-sessions"; import termSessionStore from "./terminal-sessions";
import queryClient from "@/lib/queryClient";
type AuthStore = { type AuthStore = {
token?: string | null; token: string | null;
teamId: string | null;
}; };
const authStore = createStore( const authStore = createStore(
persist<AuthStore>( persist<AuthStore>(
() => ({ () => ({
token: null, token: null,
teamId: null,
}), }),
{ {
name: "vaulterm:auth", name: "vaulterm:auth",
@ -24,9 +27,18 @@ export const useAuthStore = () => {
return { ...state, isLoggedIn: state.token != null }; return { ...state, isLoggedIn: state.token != null };
}; };
export const setTeam = (teamId: string | null) => {
authStore.setState({ teamId });
queryClient.invalidateQueries();
};
export const logout = () => { export const logout = () => {
authStore.setState({ token: null }); authStore.setState({ token: null, teamId: null });
termSessionStore.setState({ sessions: [], curSession: 0 }); termSessionStore.setState({ sessions: [], curSession: 0 });
}; };
export const useTeamId = () => {
return useStore(authStore, (i) => i.teamId);
};
export default authStore; export default authStore;

View File

@ -54,7 +54,21 @@ func login(c *fiber.Ctx) error {
func getUser(c *fiber.Ctx) error { func getUser(c *fiber.Ctx) error {
user := utils.GetUser(c) user := utils.GetUser(c)
return c.JSON(user) teams := []TeamWithRole{}
for _, item := range user.Teams {
teams = append(teams, TeamWithRole{
ID: item.TeamID,
Name: item.Team.Name,
Icon: item.Team.Icon,
Role: item.Role,
})
}
return c.JSON(&GetUserResult{
AuthUser: *user,
Teams: teams,
})
} }
func logout(c *fiber.Ctx) error { func logout(c *fiber.Ctx) error {

View File

@ -1,6 +1,20 @@
package auth package auth
import "rul.sh/vaulterm/middleware"
type LoginSchema struct { type LoginSchema struct {
Username string `json:"username"` Username string `json:"username"`
Password string `json:"password"` Password string `json:"password"`
} }
type TeamWithRole struct {
ID string `json:"id"`
Name string `json:"name"`
Icon string `json:"icon"`
Role string `json:"role"`
}
type GetUserResult struct {
middleware.AuthUser
Teams []TeamWithRole `json:"teams"`
}

View File

@ -20,8 +20,14 @@ func NewRepository(r *Hosts) *Hosts {
return r return r
} }
func (r *Hosts) GetAll() ([]*models.Host, error) { func (r *Hosts) GetAll(opt GetAllOpt) ([]*models.Host, error) {
query := r.ACL(r.db.Order("id DESC")) query := r.db.Order("id DESC")
if opt.TeamID != "" {
query = query.Where("hosts.team_id = ?", opt.TeamID)
} else {
query = query.Where("hosts.owner_id = ? AND hosts.team_id IS NULL", r.User.ID)
}
var rows []*models.Host var rows []*models.Host
ret := query.Find(&rows) ret := query.Find(&rows)
@ -30,10 +36,8 @@ func (r *Hosts) GetAll() ([]*models.Host, error) {
} }
func (r *Hosts) Get(id string) (*models.HostDecrypted, error) { func (r *Hosts) Get(id string) (*models.HostDecrypted, error) {
query := r.ACL(r.db)
var host models.Host var host models.Host
ret := query.Joins("Key").Joins("AltKey").Where("hosts.id = ?", id).First(&host) ret := r.db.Joins("Key").Joins("AltKey").Where("hosts.id = ?", id).First(&host)
if ret.Error != nil { if ret.Error != nil {
return nil, ret.Error return nil, ret.Error
} }
@ -48,13 +52,12 @@ func (r *Hosts) Get(id string) (*models.HostDecrypted, error) {
func (r *Hosts) Exists(id string) (bool, error) { func (r *Hosts) Exists(id string) (bool, error) {
var count int64 var count int64
ret := r.ACL(r.db.Model(&models.Host{}).Where("id = ?", id)).Count(&count) ret := r.db.Model(&models.Host{}).Where("id = ?", id).Count(&count)
return count > 0, ret.Error return count > 0, ret.Error
} }
func (r *Hosts) Delete(id string) error { func (r *Hosts) Delete(id string) error {
query := r.ACL(r.db) return r.db.Delete(&models.Host{Model: models.Model{ID: id}}).Error
return query.Delete(&models.Host{Model: models.Model{ID: id}}).Error
} }
func (r *Hosts) Create(item *models.Host) error { func (r *Hosts) Create(item *models.Host) error {
@ -62,15 +65,5 @@ func (r *Hosts) Create(item *models.Host) error {
} }
func (r *Hosts) Update(id string, item *models.Host) error { func (r *Hosts) Update(id string, item *models.Host) error {
query := r.ACL(r.db.Where("id = ?", id)) return r.db.Where("id = ?", id).Updates(item).Error
return query.Updates(item).Error
}
func (r *Hosts) ACL(query *gorm.DB) *gorm.DB {
if r.User.IsAdmin {
return query
}
return query.Where("hosts.owner_id = ?", r.User.ID)
} }

View File

@ -1,6 +1,7 @@
package hosts package hosts
import ( import (
"errors"
"fmt" "fmt"
"net/http" "net/http"
@ -19,10 +20,15 @@ func Router(app fiber.Router) {
} }
func getAll(c *fiber.Ctx) error { func getAll(c *fiber.Ctx) error {
teamId := c.Query("teamId")
user := utils.GetUser(c) user := utils.GetUser(c)
repo := NewRepository(&Hosts{User: user}) repo := NewRepository(&Hosts{User: user})
rows, err := repo.GetAll() if teamId != "" && !user.IsInTeam(&teamId) {
return utils.ResponseError(c, errors.New("no access"), 403)
}
rows, err := repo.GetAll(GetAllOpt{TeamID: teamId})
if err != nil { if err != nil {
return utils.ResponseError(c, err, 500) return utils.ResponseError(c, err, 500)
} }
@ -41,8 +47,13 @@ func create(c *fiber.Ctx) error {
user := utils.GetUser(c) user := utils.GetUser(c)
repo := NewRepository(&Hosts{User: user}) repo := NewRepository(&Hosts{User: user})
if body.TeamID != nil && !user.TeamCanWrite(body.TeamID) {
return utils.ResponseError(c, errors.New("no access"), 403)
}
item := &models.Host{ item := &models.Host{
OwnerID: user.ID, OwnerID: &user.ID,
TeamID: body.TeamID,
Type: body.Type, Type: body.Type,
Label: body.Label, Label: body.Label,
Host: body.Host, Host: body.Host,
@ -76,13 +87,17 @@ func update(c *fiber.Ctx) error {
repo := NewRepository(&Hosts{User: user}) repo := NewRepository(&Hosts{User: user})
id := c.Params("id") id := c.Params("id")
exist, _ := repo.Exists(id) data, _ := repo.Get(id)
if !exist { if data == nil {
return utils.ResponseError(c, fmt.Errorf("host %s not found", id), 404) return utils.ResponseError(c, errors.New("host not found"), 404)
}
if !data.CanWrite(&user.User) || !user.TeamCanWrite(body.TeamID) {
return utils.ResponseError(c, errors.New("no access"), 403)
} }
item := &models.Host{ item := &models.Host{
Model: models.Model{ID: id}, Model: models.Model{ID: id},
TeamID: body.TeamID,
Type: body.Type, Type: body.Type,
Label: body.Label, Label: body.Label,
Host: body.Host, Host: body.Host,
@ -111,9 +126,12 @@ func delete(c *fiber.Ctx) error {
repo := NewRepository(&Hosts{User: user}) repo := NewRepository(&Hosts{User: user})
id := c.Params("id") id := c.Params("id")
exist, _ := repo.Exists(id) host, _ := repo.Get(id)
if !exist { if host == nil {
return utils.ResponseError(c, fmt.Errorf("host %s not found", id), 404) return utils.ResponseError(c, errors.New("host not found"), 404)
}
if !host.CanWrite(&user.User) {
return utils.ResponseError(c, errors.New("no access"), 403)
} }
if err := repo.Delete(id); err != nil { if err := repo.Delete(id); err != nil {

View File

@ -9,7 +9,12 @@ type CreateHostSchema struct {
Port int `json:"port"` Port int `json:"port"`
Metadata datatypes.JSONMap `json:"metadata"` Metadata datatypes.JSONMap `json:"metadata"`
TeamID *string `json:"teamId"`
ParentID *string `json:"parentId"` ParentID *string `json:"parentId"`
KeyID *string `json:"keyId"` KeyID *string `json:"keyId"`
AltKeyID *string `json:"altKeyId"` AltKeyID *string `json:"altKeyId"`
} }
type GetAllOpt struct {
TeamID string
}

View File

@ -20,10 +20,16 @@ func NewRepository(r *Keychains) *Keychains {
return r return r
} }
func (r *Keychains) GetAll() ([]*models.Keychain, error) { func (r *Keychains) GetAll(opt GetAllOpt) ([]*models.Keychain, error) {
var rows []*models.Keychain query := r.db.Order("created_at DESC")
query := r.ACL(r.db.Order("created_at DESC"))
if opt.TeamID != "" {
query = query.Where("keychains.team_id = ?", opt.TeamID)
} else {
query = query.Where("keychains.owner_id = ? AND keychains.team_id IS NULL", r.User.ID)
}
var rows []*models.Keychain
ret := query.Find(&rows) ret := query.Find(&rows)
return rows, ret.Error return rows, ret.Error
} }
@ -34,9 +40,7 @@ func (r *Keychains) Create(item *models.Keychain) error {
func (r *Keychains) Get(id string) (*models.Keychain, error) { func (r *Keychains) Get(id string) (*models.Keychain, error) {
var keychain models.Keychain var keychain models.Keychain
query := r.ACL(r.db.Where("id = ?", id)) if err := r.db.Where("id = ?", id).First(&keychain).Error; err != nil {
if err := query.First(&keychain).Error; err != nil {
return nil, err return nil, err
} }
@ -45,8 +49,7 @@ func (r *Keychains) Get(id string) (*models.Keychain, error) {
func (r *Keychains) Exists(id string) (bool, error) { func (r *Keychains) Exists(id string) (bool, error) {
var count int64 var count int64
query := r.ACL(r.db.Model(&models.Keychain{}).Where("id = ?", id)) ret := r.db.Model(&models.Keychain{}).Where("id = ?", id).Count(&count)
ret := query.Count(&count)
return count > 0, ret.Error return count > 0, ret.Error
} }
@ -70,14 +73,5 @@ func (r *Keychains) GetDecrypted(id string) (*KeychainDecrypted, error) {
} }
func (r *Keychains) Update(id string, item *models.Keychain) error { func (r *Keychains) Update(id string, item *models.Keychain) error {
query := r.ACL(r.db.Where("id = ?", id)) return r.db.Where("id = ?", id).Updates(item).Error
return query.Updates(item).Error
}
func (r *Keychains) ACL(query *gorm.DB) *gorm.DB {
if r.User.IsAdmin {
return query
}
return query.Where("keychains.owner_id = ?", r.User.ID)
} }

View File

@ -1,7 +1,7 @@
package keychains package keychains
import ( import (
"fmt" "errors"
"net/http" "net/http"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
@ -23,17 +23,22 @@ type GetAllResult struct {
} }
func getAll(c *fiber.Ctx) error { func getAll(c *fiber.Ctx) error {
teamId := c.Query("teamId")
withData := c.Query("withData") withData := c.Query("withData")
user := utils.GetUser(c) user := utils.GetUser(c)
repo := NewRepository(&Keychains{User: user}) repo := NewRepository(&Keychains{User: user})
rows, err := repo.GetAll() if teamId != "" && !user.IsInTeam(&teamId) {
return utils.ResponseError(c, errors.New("no access"), 403)
}
rows, err := repo.GetAll(GetAllOpt{TeamID: teamId})
if err != nil { if err != nil {
return utils.ResponseError(c, err, 500) return utils.ResponseError(c, err, 500)
} }
if withData != "true" { if withData != "true" || (teamId != "" && !user.TeamCanWrite(&teamId)) {
return c.JSON(fiber.Map{"rows": rows}) return c.JSON(fiber.Map{"rows": rows})
} }
@ -67,8 +72,13 @@ func create(c *fiber.Ctx) error {
user := utils.GetUser(c) user := utils.GetUser(c)
repo := NewRepository(&Keychains{User: user}) repo := NewRepository(&Keychains{User: user})
if body.TeamID != nil && !user.TeamCanWrite(body.TeamID) {
return utils.ResponseError(c, errors.New("no access"), 403)
}
item := &models.Keychain{ item := &models.Keychain{
OwnerID: user.ID, OwnerID: &user.ID,
TeamID: body.TeamID,
Type: body.Type, Type: body.Type,
Label: body.Label, Label: body.Label,
} }
@ -94,15 +104,18 @@ func update(c *fiber.Ctx) error {
repo := NewRepository(&Keychains{User: user}) repo := NewRepository(&Keychains{User: user})
id := c.Params("id") id := c.Params("id")
data, _ := repo.Get(id)
exist, _ := repo.Exists(id) if data == nil {
if !exist { return utils.ResponseError(c, errors.New("key not found"), 404)
return utils.ResponseError(c, fmt.Errorf("key %s not found", id), 404) }
if !data.CanWrite(&user.User) || !user.TeamCanWrite(body.TeamID) {
return utils.ResponseError(c, errors.New("no access"), 403)
} }
item := &models.Keychain{ item := &models.Keychain{
Type: body.Type, TeamID: body.TeamID,
Label: body.Label, Type: body.Type,
Label: body.Label,
} }
if err := item.EncryptData(body.Data); err != nil { if err := item.EncryptData(body.Data); err != nil {

View File

@ -1,7 +1,12 @@
package keychains package keychains
type CreateKeychainSchema struct { type CreateKeychainSchema struct {
Type string `json:"type"` TeamID *string `json:"teamId"`
Label string `json:"label"` Type string `json:"type"`
Data interface{} `json:"data"` Label string `json:"label"`
Data interface{} `json:"data"`
}
type GetAllOpt struct {
TeamID string
} }

View File

@ -0,0 +1,50 @@
package teams
import (
"gorm.io/gorm"
"rul.sh/vaulterm/db"
"rul.sh/vaulterm/models"
"rul.sh/vaulterm/utils"
)
type Teams struct {
db *gorm.DB
User *utils.UserContext
}
func NewRepository(r *Teams) *Teams {
if r == nil {
r = &Teams{}
}
r.db = db.Get()
return r
}
func (r *Teams) GetAll() ([]*models.Team, error) {
var rows []*models.Team
ret := r.db.Order("created_at DESC").Find(&rows)
return rows, ret.Error
}
func (r *Teams) Create(data *models.Team) error {
return r.db.Create(data).Error
}
func (r *Teams) Get(id string) (*models.Team, error) {
var data models.Team
if err := r.db.Where("id = ?", id).First(&data).Error; err != nil {
return nil, err
}
return &data, nil
}
func (r *Teams) Exists(id string) (bool, error) {
var count int64
ret := r.db.Model(&models.Team{}).Where("id = ?", id).Count(&count)
return count > 0, ret.Error
}
func (r *Teams) Update(id string, item *models.Team) error {
return r.db.Where("id = ?", id).Updates(item).Error
}

View File

@ -17,8 +17,8 @@ func HandleTerm(c *websocket.Conn) {
hostRepo := hosts.NewRepository(&hosts.Hosts{User: user}) hostRepo := hosts.NewRepository(&hosts.Hosts{User: user})
data, err := hostRepo.Get(hostId) data, err := hostRepo.Get(hostId)
if data == nil { if data == nil || !data.HasAccess(&user.User) {
log.Printf("Cannot find host! Error: %s\n", err.Error()) log.Printf("Cannot find host! %v\n", err)
c.WriteMessage(websocket.TextMessage, []byte("Host not found")) c.WriteMessage(websocket.TextMessage, []byte("Host not found"))
return return
} }

View File

@ -45,7 +45,6 @@ func Init() {
// Migrate the schema // Migrate the schema
db.AutoMigrate(Models...) db.AutoMigrate(Models...)
InitModels(db)
runSeeders(db) runSeeders(db)
} }

View File

@ -1,9 +1,6 @@
package db package db
import ( import (
"log"
"gorm.io/gorm"
"rul.sh/vaulterm/models" "rul.sh/vaulterm/models"
) )
@ -15,9 +12,3 @@ var Models = []interface{}{
&models.Team{}, &models.Team{},
&models.TeamMembers{}, &models.TeamMembers{},
} }
func InitModels(db *gorm.DB) {
if err := db.SetupJoinTable(&models.Team{}, "Members", &models.TeamMembers{}); err != nil {
log.Fatal(err)
}
}

View File

@ -66,9 +66,9 @@ func seedUsers(tx *gorm.DB) error {
} }
teamMembers := []models.TeamMembers{ teamMembers := []models.TeamMembers{
{TeamID: teams[0].ID, UserID: userList[0].ID, Role: "owner"}, {TeamID: teams[0].ID, UserID: userList[0].ID, Role: models.TeamRoleOwner},
{TeamID: teams[0].ID, UserID: userList[1].ID, Role: "admin"}, {TeamID: teams[0].ID, UserID: userList[1].ID, Role: models.TeamRoleAdmin},
{TeamID: teams[0].ID, UserID: userList[2].ID, Role: "user"}, {TeamID: teams[0].ID, UserID: userList[2].ID, Role: models.TeamRoleMember},
} }
if res := tx.Create(&teamMembers); res.Error != nil { if res := tx.Create(&teamMembers); res.Error != nil {

View File

@ -4,7 +4,6 @@ import (
"strings" "strings"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"gorm.io/gorm"
"rul.sh/vaulterm/db" "rul.sh/vaulterm/db"
"rul.sh/vaulterm/models" "rul.sh/vaulterm/models"
) )
@ -23,25 +22,35 @@ func Auth(c *fiber.Ctx) error {
session, _ := GetUserSession(sessionId) session, _ := GetUserSession(sessionId)
if session != nil && session.User.ID != "" { if session != nil && session.ID != "" {
c.Locals("user", &session.User) c.Locals("user", session)
c.Locals("sessionId", sessionId) c.Locals("sessionId", sessionId)
} }
return c.Next() return c.Next()
} }
func GetUserSession(sessionId string) (*models.UserSession, error) { type AuthUser struct {
var session models.UserSession models.User
SessionID string `json:"sessionId" gorm:"column:session_id"`
}
func GetUserSession(sessionId string) (*AuthUser, error) {
var session AuthUser
res := db.Get(). res := db.Get().
Joins("User"). Model(&models.User{}).
Preload("User.Teams", func(db *gorm.DB) *gorm.DB { Joins("JOIN user_sessions ON user_sessions.user_id = users.id").
return db.Select("id", "name", "icon") Preload("Teams.Team").
}). Select("users.*, user_sessions.id AS session_id").
Where("user_sessions.id = ?", sessionId). Where("user_sessions.id = ?", sessionId).
First(&session) First(&session)
return &session, res.Error if res.Error != nil || session.User.ID == "" {
return nil, res.Error
}
return &session, nil
} }
func Protected() func(c *fiber.Ctx) error { func Protected() func(c *fiber.Ctx) error {

View File

@ -1,6 +1,8 @@
package models package models
import "gorm.io/datatypes" import (
"gorm.io/datatypes"
)
const ( const (
HostTypeSSH = "ssh" HostTypeSSH = "ssh"
@ -14,8 +16,10 @@ const (
type Host struct { type Host struct {
Model Model
OwnerID string `json:"userId" gorm:"index:hosts_owner_id_idx;type:varchar(26)"` OwnerID *string `json:"userId" gorm:"type:varchar(26)"`
Owner User `json:"user" gorm:"foreignKey:OwnerID"` Owner *User `json:"user" gorm:"foreignKey:OwnerID"`
TeamID *string `json:"teamId" gorm:"type:varchar(26)"`
Team *Team `json:"team" gorm:"foreignKey:TeamID"`
Type string `json:"type" gorm:"not null;index:hosts_type_idx;type:varchar(16)"` Type string `json:"type" gorm:"not null;index:hosts_type_idx;type:varchar(16)"`
Label string `json:"label"` Label string `json:"label"`
@ -24,11 +28,11 @@ type Host struct {
OS string `json:"os" gorm:"type:varchar(32)"` OS string `json:"os" gorm:"type:varchar(32)"`
Metadata datatypes.JSONMap `json:"metadata"` Metadata datatypes.JSONMap `json:"metadata"`
ParentID *string `json:"parentId" gorm:"index:hosts_parent_id_idx;type:varchar(26)"` ParentID *string `json:"parentId" gorm:"type:varchar(26)"`
Parent *Host `json:"parent" gorm:"foreignKey:ParentID"` Parent *Host `json:"parent" gorm:"foreignKey:ParentID"`
KeyID *string `json:"keyId" gorm:"index:hosts_key_id_idx"` KeyID *string `json:"keyId" gorm:"type:varchar(26)"`
Key Keychain `json:"key" gorm:"foreignKey:KeyID"` Key Keychain `json:"key" gorm:"foreignKey:KeyID"`
AltKeyID *string `json:"altKeyId" gorm:"index:hosts_altkey_id_idx"` AltKeyID *string `json:"altKeyId" gorm:"type:varchar(26)"`
AltKey Keychain `json:"altKey" gorm:"foreignKey:AltKeyID"` AltKey Keychain `json:"altKey" gorm:"foreignKey:AltKeyID"`
Timestamps Timestamps
@ -58,13 +62,17 @@ func (h *Host) DecryptKeys() (*HostDecrypted, error) {
return res, nil return res, nil
} }
type HostHasAccessOptions struct { func (h *Host) HasAccess(user *User) bool {
UserID string if user.IsAdmin() {
}
func (h *Host) HasAccess(o HostHasAccessOptions) bool {
if o.UserID == h.OwnerID {
return true return true
} }
return false return *h.OwnerID == user.ID || user.IsInTeam(h.TeamID)
}
func (h *Host) CanWrite(user *User) bool {
if user.IsAdmin() {
return true
}
teamRole := user.GetTeamRole(h.TeamID)
return *h.OwnerID == user.ID || teamRole == TeamRoleOwner || teamRole == TeamRoleAdmin
} }

View File

@ -16,8 +16,10 @@ const (
type Keychain struct { type Keychain struct {
Model Model
OwnerID string `json:"userId" gorm:"index:hosts_owner_id_idx;type:varchar(26)"` OwnerID *string `json:"userId" gorm:"type:varchar(26)"`
Owner User `json:"user" gorm:"foreignKey:OwnerID"` Owner *User `json:"user" gorm:"foreignKey:OwnerID"`
TeamID *string `json:"teamId" gorm:"type:varchar(26)"`
Team *Team `json:"team" gorm:"foreignKey:TeamID"`
Label string `json:"label"` Label string `json:"label"`
Type string `json:"type" gorm:"not null;index:keychains_type_idx;type:varchar(12)"` Type string `json:"type" gorm:"not null;index:keychains_type_idx;type:varchar(12)"`
@ -55,3 +57,18 @@ func (k *Keychain) DecryptData(data interface{}) error {
return nil return nil
} }
func (k *Keychain) HasAccess(user *User) bool {
if user.IsAdmin() {
return true
}
return *k.OwnerID == user.ID || user.IsInTeam(k.TeamID)
}
func (k *Keychain) CanWrite(user *User) bool {
if user.IsAdmin() {
return true
}
teamRole := user.GetTeamRole(k.TeamID)
return *k.OwnerID == user.ID || teamRole == TeamRoleOwner || teamRole == TeamRoleAdmin
}

View File

@ -2,12 +2,18 @@ package models
import "time" import "time"
const (
TeamRoleOwner = "owner"
TeamRoleAdmin = "admin"
TeamRoleMember = "member"
)
type Team struct { type Team struct {
Model Model
Name string `json:"name" gorm:"type:varchar(32)"` Name string `json:"name" gorm:"type:varchar(32)"`
Icon string `json:"icon" gorm:"type:varchar(2)"` Icon string `json:"icon" gorm:"type:varchar(2)"`
Members []*User `json:"members" gorm:"many2many:team_members"` Members []*TeamMembers `json:"members" gorm:"foreignKey:TeamID"`
Timestamps Timestamps
SoftDeletes SoftDeletes

View File

@ -1,5 +1,7 @@
package models package models
import "slices"
const ( const (
UserRoleUser = "user" UserRoleUser = "user"
UserRoleAdmin = "admin" UserRoleAdmin = "admin"
@ -14,7 +16,7 @@ type User struct {
Email string `json:"email" gorm:"unique"` Email string `json:"email" gorm:"unique"`
Role string `json:"role" gorm:"default:user;not null;index:users_role_idx;type:varchar(8)"` Role string `json:"role" gorm:"default:user;not null;index:users_role_idx;type:varchar(8)"`
Teams []*Team `json:"teams" gorm:"many2many:team_members"` Teams []*TeamMembers `json:"teams" gorm:"foreignKey:UserID"`
Timestamps Timestamps
SoftDeletes SoftDeletes
@ -28,3 +30,33 @@ type UserSession struct {
Timestamps Timestamps
SoftDeletes SoftDeletes
} }
func (u *User) IsAdmin() bool {
return u.Role == UserRoleAdmin
}
func (u *User) GetTeamRole(teamId *string) string {
if u.IsAdmin() {
return TeamRoleAdmin
}
if teamId == nil {
return ""
}
idx := slices.IndexFunc(u.Teams, func(tm *TeamMembers) bool {
return tm.TeamID == *teamId
})
if idx == -1 {
return ""
}
return u.Teams[idx].Role
}
func (u *User) IsInTeam(teamId *string) bool {
role := u.GetTeamRole(teamId)
return role != ""
}
func (u *User) TeamCanWrite(teamId *string) bool {
role := u.GetTeamRole(teamId)
return role == TeamRoleAdmin || role == TeamRoleOwner
}

View File

@ -3,33 +3,17 @@ package utils
import ( import (
"github.com/gofiber/contrib/websocket" "github.com/gofiber/contrib/websocket"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"rul.sh/vaulterm/models" "rul.sh/vaulterm/middleware"
) )
type UserContext struct { type UserContext = middleware.AuthUser
*models.User
IsAdmin bool `json:"isAdmin"`
}
func getUserData(user *models.User) *UserContext {
isAdmin := false
if user.Role == models.UserRoleAdmin {
isAdmin = true
}
return &UserContext{
User: user,
IsAdmin: isAdmin,
}
}
func GetUser(c *fiber.Ctx) *UserContext { func GetUser(c *fiber.Ctx) *UserContext {
user := c.Locals("user").(*models.User) user, _ := c.Locals("user").(*UserContext)
return getUserData(user) return user
} }
func GetUserWs(c *websocket.Conn) *UserContext { func GetUserWs(c *websocket.Conn) *UserContext {
user := c.Locals("user").(*models.User) user, _ := c.Locals("user").(*UserContext)
return getUserData(user) return user
} }