feat: add auth, hosts, & keychains ownership

This commit is contained in:
Khairul Hidayat 2024-11-10 18:49:18 +07:00
parent b2cc6778a6
commit 38e81049a1
31 changed files with 579 additions and 92 deletions

View File

@ -12,6 +12,7 @@ import { router, usePathname, useRootNavigationState } from "expo-router";
import { useAuthStore } from "@/stores/auth";
import { PortalProvider } from "tamagui";
import { queryClient } from "@/lib/api";
import { useAppStore } from "@/stores/app";
type Props = PropsWithChildren;
@ -53,16 +54,24 @@ const AuthProvider = () => {
const pathname = usePathname();
const rootNavigationState = useRootNavigationState();
const { isLoggedIn } = useAuthStore();
const { curServer } = useAppStore();
useEffect(() => {
if (!rootNavigationState?.key) {
return;
}
if (!pathname.startsWith("/auth") && !isLoggedIn) {
if (!curServer && !pathname.startsWith("/server")) {
router.replace("/server");
return;
}
const isProtected = !["/auth", "/server"].find((path) =>
pathname.startsWith(path)
);
if (isProtected && !isLoggedIn) {
router.replace("/auth/login");
} else if (pathname.startsWith("/auth") && isLoggedIn) {
router.replace("/");
}
}, [pathname, rootNavigationState, isLoggedIn]);

View File

@ -1,14 +1,3 @@
import { View, Text, Button } from "tamagui";
import React from "react";
import authStore from "@/stores/auth";
import LoginPage from "@/pages/auth/login";
export default function LoginPage() {
return (
<View>
<Text>LoginPage</Text>
<Button onPress={() => authStore.setState({ token: "123" })}>
Login
</Button>
</View>
);
}
export default LoginPage;

View File

@ -1,9 +1,15 @@
import React from "react";
import { Redirect } from "expo-router";
import { useTermSession } from "@/stores/terminal-sessions";
import { useAppStore } from "@/stores/app";
export default function index() {
const { sessions, curSession } = useTermSession();
const { servers, curServer } = useAppStore();
if (!servers.length || !curServer) {
return <Redirect href="/server" />;
}
return (
<Redirect

View File

@ -0,0 +1,3 @@
import ServerPage from "@/pages/server/page";
export default ServerPage;

View File

@ -2,6 +2,7 @@ import React from "react";
import Terminal from "./terminal";
import { BASE_WS_URL } from "@/lib/api";
import VNCViewer from "./vncviewer";
import { useAuthStore } from "@/stores/auth";
type SSHSessionProps = {
type: "ssh";
@ -28,7 +29,8 @@ export type InteractiveSessionProps = {
} & (SSHSessionProps | PVESessionProps | IncusSessionProps);
const InteractiveSession = ({ type, params }: InteractiveSessionProps) => {
const query = new URLSearchParams(params);
const { token } = useAuthStore();
const query = new URLSearchParams({ ...params, sid: token || "" });
const url = `${BASE_WS_URL}/ws/term?${query}`;
switch (type) {

View File

@ -0,0 +1,29 @@
import React from "react";
import { Button, GetProps } from "tamagui";
import Icons from "../ui/icons";
import useThemeStore from "@/stores/theme";
type Props = GetProps<typeof Button> & {
iconSize?: number;
};
const ThemeSwitcher = ({ iconSize = 24, ...props }: Props) => {
const { theme, toggle } = useThemeStore();
return (
<Button
icon={
<Icons
name={
theme === "light" ? "white-balance-sunny" : "moon-waning-crescent"
}
size={iconSize}
/>
}
onPress={toggle}
{...props}
/>
);
};
export default ThemeSwitcher;

View File

@ -4,11 +4,21 @@ import { Label, Text, View, XStack } from "tamagui";
type FormFieldProps = ComponentPropsWithoutRef<typeof XStack> & {
label?: string;
htmlFor?: string;
vertical?: boolean;
};
const FormField = ({ label, htmlFor, ...props }: FormFieldProps) => {
const FormField = ({
label,
htmlFor,
vertical = false,
...props
}: FormFieldProps) => {
return (
<XStack alignItems="flex-start" {...props}>
<XStack
flexDirection={vertical ? "column" : "row"}
alignItems={vertical ? "stretch" : "flex-start"}
{...props}
>
<Label htmlFor={htmlFor} w={120} $xs={{ w: 100 }}>
{label}
</Label>

View File

@ -1,3 +1,4 @@
import authStore from "@/stores/auth";
import { QueryClient } from "@tanstack/react-query";
import { ofetch } from "ofetch";
@ -6,7 +7,18 @@ export const BASE_WS_URL = BASE_API_URL.replace("http", "ws");
const api = ofetch.create({
baseURL: BASE_API_URL,
onRequest: (config) => {
const authToken = authStore.getState().token;
if (authToken) {
config.options.headers.set("Authorization", `Bearer ${authToken}`);
}
},
onResponseError: (error) => {
if (error.response.status === 401 && !!authStore.getState().token) {
authStore.setState({ token: null });
throw new Error("Unauthorized");
}
if (error.response._data) {
const message = error.response._data.message;
throw new Error(message || "Something went wrong");

View File

@ -0,0 +1,92 @@
import { Text, ScrollView, Card, Separator } from "tamagui";
import React from "react";
import FormField from "@/components/ui/form";
import { InputField } from "@/components/ui/input";
import { useZForm } from "@/hooks/useZForm";
import { router, Stack } from "expo-router";
import Button from "@/components/ui/button";
import ThemeSwitcher from "@/components/containers/theme-switcher";
import { useMutation } from "@tanstack/react-query";
import { z } from "zod";
import { ErrorAlert } from "@/components/ui/alert";
import { loginResultSchema, loginSchema } from "./schema";
import api from "@/lib/api";
import Icons from "@/components/ui/icons";
import authStore from "@/stores/auth";
export default function LoginPage() {
const form = useZForm(loginSchema);
const login = useMutation({
mutationFn: async (body: z.infer<typeof loginSchema>) => {
const res = await api("/auth/login", { method: "POST", body });
const { data } = loginResultSchema.safeParse(res);
if (!data) {
throw new Error("Invalid response!");
}
return data;
},
onSuccess(data) {
authStore.setState({ token: data.sessionId });
router.replace("/");
},
});
const onSubmit = form.handleSubmit((values) => {
login.mutate(values);
});
return (
<>
<Stack.Screen
options={{
contentStyle: {
width: "100%",
maxWidth: 600,
marginHorizontal: "auto",
},
title: "Login",
headerRight: () => (
<ThemeSwitcher bg="$colorTransparent" $gtSm={{ mr: "$3" }} />
),
}}
/>
<ScrollView
contentContainerStyle={{
padding: "$4",
pb: "$12",
justifyContent: "center",
flexGrow: 1,
}}
>
<Card bordered p="$4" gap="$4">
<Text fontSize="$8">Login</Text>
<ErrorAlert error={login.error} />
<FormField vertical label="Username/Email">
<InputField form={form} name="username" />
</FormField>
<FormField vertical label="Password">
<InputField form={form} name="password" secureTextEntry />
</FormField>
<Separator />
<Button
icon={<Icons name="lock" size={16} />}
onPress={onSubmit}
isLoading={login.isPending}
>
Connect
</Button>
<Button onPress={() => router.push("/server")} bg="$colorTransparent">
Change Server
</Button>
</Card>
</ScrollView>
</>
);
}

View File

@ -0,0 +1,10 @@
import { z } from "zod";
export const loginSchema = z.object({
username: z.string(),
password: z.string(),
});
export const loginResultSchema = z.object({
sessionId: z.string().min(40),
});

View File

@ -0,0 +1,78 @@
import { View, Text, ScrollView, Card } from "tamagui";
import React from "react";
import FormField from "@/components/ui/form";
import { InputField } from "@/components/ui/input";
import { useZForm } from "@/hooks/useZForm";
import { getServerResultSchema, serverSchema } from "./schema";
import { router, Stack } from "expo-router";
import Button from "@/components/ui/button";
import ThemeSwitcher from "@/components/containers/theme-switcher";
import { useMutation } from "@tanstack/react-query";
import { ofetch } from "ofetch";
import { z } from "zod";
import { ErrorAlert } from "@/components/ui/alert";
import { addServer } from "@/stores/app";
export default function ServerPage() {
const form = useZForm(serverSchema);
const serverConnect = useMutation({
mutationFn: async (body: z.infer<typeof serverSchema>) => {
const res = await ofetch(body.url + "/server");
const { data } = getServerResultSchema.safeParse(res);
if (!data) {
throw new Error("Invalid server");
}
return data;
},
onSuccess(data, payload) {
addServer({ url: payload.url, name: data.name }, true);
router.replace("/auth/login");
},
});
const onSubmit = form.handleSubmit((values) => {
serverConnect.mutate(values);
});
return (
<>
<Stack.Screen
options={{
contentStyle: {
width: "100%",
maxWidth: 600,
marginHorizontal: "auto",
},
title: "Vaulterm",
headerRight: () => (
<ThemeSwitcher bg="$colorTransparent" $gtSm={{ mr: "$3" }} />
),
}}
/>
<ScrollView
contentContainerStyle={{
padding: "$4",
pb: "$12",
justifyContent: "center",
flexGrow: 1,
}}
>
<Card bordered p="$4" gap="$4">
<Text fontSize="$8">Connect to Server</Text>
<ErrorAlert error={serverConnect.error} />
<FormField vertical label="URL">
<InputField form={form} name="url" placeholder="https://" />
</FormField>
<Button onPress={onSubmit} isLoading={serverConnect.isPending}>
Connect
</Button>
</Card>
</ScrollView>
</>
);
}

View File

@ -0,0 +1,10 @@
import { z } from "zod";
export const serverSchema = z.object({
url: z.string().url("Invalid URL"),
});
export const getServerResultSchema = z.object({
name: z.string(),
version: z.string().min(1),
});

62
frontend/stores/app.ts Normal file
View File

@ -0,0 +1,62 @@
import { createStore, useStore } from "zustand";
import { persist, createJSONStorage } from "zustand/middleware";
import AsyncStorage from "@react-native-async-storage/async-storage";
type AppServer = {
name?: string;
url: string;
};
type AppStore = {
servers: AppServer[];
curServerIdx?: number | null;
};
const appStore = createStore(
persist<AppStore>(
() => ({
servers: [],
curServerIdx: null,
}),
{
name: "vaulterm:app",
storage: createJSONStorage(() => AsyncStorage),
}
)
);
export function addServer(srv: AppServer, setActive?: boolean) {
const curServers = appStore.getState().servers;
const isExist = curServers.findIndex((s) => s.url === srv.url);
if (isExist >= 0) {
setActiveServer(isExist);
return;
}
appStore.setState((state) => ({
servers: [...state.servers, srv],
curServerIdx: setActive ? state.servers.length : state.curServerIdx,
}));
}
export function removeServer(idx: number) {
appStore.setState((state) => ({
servers: state.servers.filter((_, i) => i !== idx),
curServerIdx: state.curServerIdx === idx ? null : state.curServerIdx,
}));
}
export function setActiveServer(idx: number) {
appStore.setState({ curServerIdx: idx });
}
export const useAppStore = () => {
const state = useStore(appStore);
const curServer =
state.curServerIdx != null ? state.servers[state.curServerIdx] : null;
return { ...state, curServer };
};
export default appStore;

View File

@ -12,7 +12,7 @@ const authStore = createStore(
token: null,
}),
{
name: "auth",
name: "vaulterm:auth",
storage: createJSONStorage(() => AsyncStorage),
}
)

View File

@ -38,6 +38,9 @@ export const useTermSession = create(
set({ curSession: idx });
},
}),
{ name: "term-sessions", storage: createJSONStorage(() => AsyncStorage) }
{
name: "vaulterm:term-sessions",
storage: createJSONStorage(() => AsyncStorage),
}
)
);

View File

@ -20,7 +20,7 @@ const useThemeStore = create(
},
}),
{
name: "theme",
name: "vaulterm:theme",
storage: createJSONStorage(() => AsyncStorage),
}
)

View File

@ -5,10 +5,8 @@ import (
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/joho/godotenv"
"rul.sh/vaulterm/app/auth"
"rul.sh/vaulterm/app/hosts"
"rul.sh/vaulterm/app/keychains"
"rul.sh/vaulterm/app/ws"
"rul.sh/vaulterm/db"
"rul.sh/vaulterm/middleware"
)
func NewApp() *fiber.App {
@ -18,20 +16,26 @@ func NewApp() *fiber.App {
// Create fiber app
app := fiber.New(fiber.Config{ErrorHandler: ErrorHandler})
// Middlewares
app.Use(cors.New())
// Init app routes
auth.Router(app)
hosts.Router(app)
keychains.Router(app)
ws.Router(app)
// Server info
app.Get("/server", func(c *fiber.Ctx) error {
return c.JSON(&fiber.Map{
"name": "Vaulterm",
"version": "0.0.1",
})
})
// Health check
app.Get("/health-check", func(c *fiber.Ctx) error {
return c.SendString("OK")
})
app.Use(middleware.Auth)
auth.Router(app)
app.Use(middleware.Protected())
InitRouter(app)
return app
}

View File

@ -9,7 +9,7 @@ import (
type Auth struct{ db *gorm.DB }
func NewAuthRepository() *Auth {
func NewRepository() *Auth {
return &Auth{db: db.Get()}
}

View File

@ -1,10 +1,9 @@
package auth
import (
"strings"
"github.com/gofiber/fiber/v2"
"rul.sh/vaulterm/lib"
"rul.sh/vaulterm/middleware"
"rul.sh/vaulterm/utils"
)
@ -12,12 +11,12 @@ func Router(app *fiber.App) {
router := app.Group("/auth")
router.Post("/login", login)
router.Get("/user", getUser)
router.Post("/logout", logout)
router.Get("/user", middleware.Protected(), getUser)
router.Post("/logout", middleware.Protected(), logout)
}
func login(c *fiber.Ctx) error {
repo := NewAuthRepository()
repo := NewRepository()
var body LoginSchema
if err := c.BodyParser(&body); err != nil {
@ -54,32 +53,15 @@ func login(c *fiber.Ctx) error {
}
func getUser(c *fiber.Ctx) error {
auth := c.Get("Authorization")
var sessionId string
if auth != "" {
sessionId = strings.Split(auth, " ")[1]
}
repo := NewAuthRepository()
session, err := repo.GetSession(sessionId)
if err != nil {
return utils.ResponseError(c, err, 500)
}
return c.JSON(session)
user := utils.GetUser(c)
return c.JSON(user)
}
func logout(c *fiber.Ctx) error {
auth := c.Get("Authorization")
force := c.Query("force")
var sessionId string
sessionId := c.Locals("sessionId").(string)
if auth != "" {
sessionId = strings.Split(auth, " ")[1]
}
repo := NewAuthRepository()
repo := NewRepository()
err := repo.RemoveUserSession(sessionId, force == "true")
if err != nil {

View File

@ -4,25 +4,36 @@ import (
"gorm.io/gorm"
"rul.sh/vaulterm/db"
"rul.sh/vaulterm/models"
"rul.sh/vaulterm/utils"
)
type Hosts struct{ db *gorm.DB }
type Hosts struct {
db *gorm.DB
User *utils.UserContext
}
func NewRepository() *Hosts {
return &Hosts{db: db.Get()}
func NewRepository(r *Hosts) *Hosts {
if r == nil {
r = &Hosts{}
}
r.db = db.Get()
return r
}
func (r *Hosts) GetAll() ([]*models.Host, error) {
query := r.ACL(r.db.Order("id DESC"))
var rows []*models.Host
ret := r.db.Order("id DESC").Find(&rows)
ret := query.Find(&rows)
return rows, ret.Error
}
func (r *Hosts) Get(id string) (*models.HostDecrypted, error) {
var host models.Host
ret := r.db.Joins("Key").Joins("AltKey").Where("hosts.id = ?", id).First(&host)
query := r.ACL(r.db)
var host models.Host
ret := query.Joins("Key").Joins("AltKey").Where("hosts.id = ?", id).First(&host)
if ret.Error != nil {
return nil, ret.Error
}
@ -37,12 +48,13 @@ func (r *Hosts) Get(id string) (*models.HostDecrypted, error) {
func (r *Hosts) Exists(id string) (bool, error) {
var count int64
ret := r.db.Model(&models.Host{}).Where("id = ?", id).Count(&count)
ret := r.ACL(r.db.Model(&models.Host{}).Where("id = ?", id)).Count(&count)
return count > 0, ret.Error
}
func (r *Hosts) Delete(id string) error {
return r.db.Delete(&models.Host{Model: models.Model{ID: id}}).Error
query := r.ACL(r.db)
return query.Delete(&models.Host{Model: models.Model{ID: id}}).Error
}
func (r *Hosts) Create(item *models.Host) error {
@ -50,5 +62,15 @@ func (r *Hosts) Create(item *models.Host) error {
}
func (r *Hosts) Update(id string, item *models.Host) error {
return r.db.Where("id = ?", id).Updates(item).Error
query := r.ACL(r.db.Where("id = ?", id))
return query.Updates(item).Error
}
func (r *Hosts) ACL(query *gorm.DB) *gorm.DB {
if r.User.IsAdmin {
return query
}
return query.Where("hosts.owner_id = ?", r.User.ID)
}

View File

@ -9,7 +9,7 @@ import (
"rul.sh/vaulterm/utils"
)
func Router(app *fiber.App) {
func Router(app fiber.Router) {
router := app.Group("/hosts")
router.Get("/", getAll)
@ -19,7 +19,9 @@ func Router(app *fiber.App) {
}
func getAll(c *fiber.Ctx) error {
repo := NewRepository()
user := utils.GetUser(c)
repo := NewRepository(&Hosts{User: user})
rows, err := repo.GetAll()
if err != nil {
return utils.ResponseError(c, err, 500)
@ -36,8 +38,11 @@ func create(c *fiber.Ctx) error {
return utils.ResponseError(c, err, 500)
}
repo := NewRepository()
user := utils.GetUser(c)
repo := NewRepository(&Hosts{User: user})
item := &models.Host{
OwnerID: user.ID,
Type: body.Type,
Label: body.Label,
Host: body.Host,
@ -48,7 +53,7 @@ func create(c *fiber.Ctx) error {
AltKeyID: body.AltKeyID,
}
osName, err := tryConnect(item)
osName, err := tryConnect(c, item)
if err != nil {
return utils.ResponseError(c, fmt.Errorf("cannot connect to the host: %s", err), 500)
}
@ -67,7 +72,8 @@ func update(c *fiber.Ctx) error {
return utils.ResponseError(c, err, 500)
}
repo := NewRepository()
user := utils.GetUser(c)
repo := NewRepository(&Hosts{User: user})
id := c.Params("id")
exist, _ := repo.Exists(id)
@ -87,7 +93,7 @@ func update(c *fiber.Ctx) error {
AltKeyID: body.AltKeyID,
}
osName, err := tryConnect(item)
osName, err := tryConnect(c, item)
if err != nil {
return utils.ResponseError(c, fmt.Errorf("cannot connect to the host: %s", err), 500)
}
@ -101,7 +107,8 @@ func update(c *fiber.Ctx) error {
}
func delete(c *fiber.Ctx) error {
repo := NewRepository()
user := utils.GetUser(c)
repo := NewRepository(&Hosts{User: user})
id := c.Params("id")
exist, _ := repo.Exists(id)

View File

@ -3,13 +3,16 @@ package hosts
import (
"fmt"
"github.com/gofiber/fiber/v2"
"rul.sh/vaulterm/app/keychains"
"rul.sh/vaulterm/lib"
"rul.sh/vaulterm/models"
"rul.sh/vaulterm/utils"
)
func tryConnect(host *models.Host) (string, error) {
keyRepo := keychains.NewRepository()
func tryConnect(c *fiber.Ctx, host *models.Host) (string, error) {
user := utils.GetUser(c)
keyRepo := keychains.NewRepository(&keychains.Keychains{User: user})
var key map[string]interface{}
var altKey map[string]interface{}

View File

@ -4,18 +4,27 @@ import (
"gorm.io/gorm"
"rul.sh/vaulterm/db"
"rul.sh/vaulterm/models"
"rul.sh/vaulterm/utils"
)
type Keychains struct{ db *gorm.DB }
type Keychains struct {
db *gorm.DB
User *utils.UserContext
}
func NewRepository() *Keychains {
return &Keychains{db: db.Get()}
func NewRepository(r *Keychains) *Keychains {
if r == nil {
r = &Keychains{}
}
r.db = db.Get()
return r
}
func (r *Keychains) GetAll() ([]*models.Keychain, error) {
var rows []*models.Keychain
ret := r.db.Order("created_at DESC").Find(&rows)
query := r.ACL(r.db.Order("created_at DESC"))
ret := query.Find(&rows)
return rows, ret.Error
}
@ -25,7 +34,9 @@ func (r *Keychains) Create(item *models.Keychain) error {
func (r *Keychains) Get(id string) (*models.Keychain, error) {
var keychain models.Keychain
if err := r.db.Where("id = ?", id).First(&keychain).Error; err != nil {
query := r.ACL(r.db.Where("id = ?", id))
if err := query.First(&keychain).Error; err != nil {
return nil, err
}
@ -34,7 +45,8 @@ func (r *Keychains) Get(id string) (*models.Keychain, error) {
func (r *Keychains) Exists(id string) (bool, error) {
var count int64
ret := r.db.Model(&models.Keychain{}).Where("id = ?", id).Count(&count)
query := r.ACL(r.db.Model(&models.Keychain{}).Where("id = ?", id))
ret := query.Count(&count)
return count > 0, ret.Error
}
@ -58,5 +70,14 @@ func (r *Keychains) GetDecrypted(id string) (*KeychainDecrypted, error) {
}
func (r *Keychains) Update(id string, item *models.Keychain) error {
return r.db.Where("id = ?", id).Updates(item).Error
query := r.ACL(r.db.Where("id = ?", id))
return query.Updates(item).Error
}
func (r *Keychains) ACL(query *gorm.DB) *gorm.DB {
if r.User.IsAdmin {
return query
}
return query.Where("keychains.owner_id = ?", r.User.ID)
}

View File

@ -9,7 +9,7 @@ import (
"rul.sh/vaulterm/utils"
)
func Router(app *fiber.App) {
func Router(app fiber.Router) {
router := app.Group("/keychains")
router.Get("/", getAll)
@ -25,7 +25,9 @@ type GetAllResult struct {
func getAll(c *fiber.Ctx) error {
withData := c.Query("withData")
repo := NewRepository()
user := utils.GetUser(c)
repo := NewRepository(&Keychains{User: user})
rows, err := repo.GetAll()
if err != nil {
return utils.ResponseError(c, err, 500)
@ -62,11 +64,13 @@ func create(c *fiber.Ctx) error {
return utils.ResponseError(c, err, 500)
}
repo := NewRepository()
user := utils.GetUser(c)
repo := NewRepository(&Keychains{User: user})
item := &models.Keychain{
Type: body.Type,
Label: body.Label,
OwnerID: user.ID,
Type: body.Type,
Label: body.Label,
}
if err := item.EncryptData(body.Data); err != nil {
@ -86,7 +90,9 @@ func update(c *fiber.Ctx) error {
return utils.ResponseError(c, err, 500)
}
repo := NewRepository()
user := utils.GetUser(c)
repo := NewRepository(&Keychains{User: user})
id := c.Params("id")
exist, _ := repo.Exists(id)

23
server/app/router.go Normal file
View File

@ -0,0 +1,23 @@
package app
import (
"github.com/gofiber/fiber/v2"
"rul.sh/vaulterm/app/hosts"
"rul.sh/vaulterm/app/keychains"
"rul.sh/vaulterm/app/ws"
)
func InitRouter(app *fiber.App) {
// App route list
routes := []Router{
hosts.Router,
keychains.Router,
ws.Router,
}
for _, route := range routes {
route(app)
}
}
type Router func(app fiber.Router)

View File

@ -5,7 +5,7 @@ import (
"github.com/gofiber/fiber/v2"
)
func Router(app *fiber.App) {
func Router(app fiber.Router) {
router := app.Group("/ws")
router.Use(func(c *fiber.Ctx) error {

View File

@ -13,7 +13,8 @@ import (
func HandleTerm(c *websocket.Conn) {
hostId := c.Query("hostId")
hostRepo := hosts.NewRepository()
user := utils.GetUserWs(c)
hostRepo := hosts.NewRepository(&hosts.Hosts{User: user})
data, err := hostRepo.Get(hostId)
if data == nil {

51
server/middleware/auth.go Normal file
View File

@ -0,0 +1,51 @@
package middleware
import (
"strings"
"github.com/gofiber/fiber/v2"
"rul.sh/vaulterm/db"
"rul.sh/vaulterm/models"
)
func Auth(c *fiber.Ctx) error {
authHeader := c.Get("Authorization")
var sessionId string
if authHeader != "" {
sessionId = strings.Split(authHeader, " ")[1]
}
if strings.HasPrefix(c.Path(), "/ws") && sessionId == "" {
sessionId = c.Query("sid")
}
session, _ := GetUserSession(sessionId)
if session != nil && session.User.ID != "" {
c.Locals("user", &session.User)
c.Locals("sessionId", sessionId)
}
return c.Next()
}
func GetUserSession(sessionId string) (*models.UserSession, error) {
var session models.UserSession
res := db.Get().Joins("User").Where("user_sessions.id = ?", sessionId).First(&session)
return &session, res.Error
}
func Protected() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
user := c.Locals("user")
if user == nil {
return &fiber.Error{
Code: fiber.StatusUnauthorized,
Message: "Unauthorized",
}
}
return c.Next()
}
}

View File

@ -14,6 +14,9 @@ const (
type Host struct {
Model
OwnerID string `json:"userId" gorm:"index:hosts_owner_id_idx;type:varchar(26)"`
Owner User `json:"user" gorm:"foreignKey:OwnerID"`
Type string `json:"type" gorm:"not null;index:hosts_type_idx;type:varchar(16)"`
Label string `json:"label"`
Host string `json:"host" gorm:"type:varchar(64)"`
@ -54,3 +57,14 @@ func (h *Host) DecryptKeys() (*HostDecrypted, error) {
return res, nil
}
type HostHasAccessOptions struct {
UserID string
}
func (h *Host) HasAccess(o HostHasAccessOptions) bool {
if o.UserID == h.OwnerID {
return true
}
return false
}

View File

@ -16,6 +16,9 @@ const (
type Keychain struct {
Model
OwnerID string `json:"userId" gorm:"index:hosts_owner_id_idx;type:varchar(26)"`
Owner User `json:"user" gorm:"foreignKey:OwnerID"`
Label string `json:"label"`
Type string `json:"type" gorm:"not null;index:keychains_type_idx;type:varchar(12)"`
Data string `json:"-" gorm:"type:text"`

35
server/utils/context.go Normal file
View File

@ -0,0 +1,35 @@
package utils
import (
"github.com/gofiber/contrib/websocket"
"github.com/gofiber/fiber/v2"
"rul.sh/vaulterm/models"
)
type UserContext struct {
*models.User
IsAdmin bool
}
func getUserData(user *models.User) *UserContext {
isAdmin := false
if user.Role == models.UserRoleAdmin {
isAdmin = true
}
return &UserContext{
User: user,
IsAdmin: isAdmin,
}
}
func GetUser(c *fiber.Ctx) *UserContext {
user := c.Locals("user").(*models.User)
return getUserData(user)
}
func GetUserWs(c *websocket.Conn) *UserContext {
user := c.Locals("user").(*models.User)
return getUserData(user)
}