feat: add auth, hosts, & keychains ownership

This commit is contained in:
Khairul Hidayat 2024-11-10 18:49:18 +07:00
parent b2cc6778a6
commit 38e81049a1
31 changed files with 579 additions and 92 deletions

View File

@ -12,6 +12,7 @@ import { router, usePathname, useRootNavigationState } from "expo-router";
import { useAuthStore } from "@/stores/auth"; import { useAuthStore } from "@/stores/auth";
import { PortalProvider } from "tamagui"; import { PortalProvider } from "tamagui";
import { queryClient } from "@/lib/api"; import { queryClient } from "@/lib/api";
import { useAppStore } from "@/stores/app";
type Props = PropsWithChildren; type Props = PropsWithChildren;
@ -53,16 +54,24 @@ const AuthProvider = () => {
const pathname = usePathname(); const pathname = usePathname();
const rootNavigationState = useRootNavigationState(); const rootNavigationState = useRootNavigationState();
const { isLoggedIn } = useAuthStore(); const { isLoggedIn } = useAuthStore();
const { curServer } = useAppStore();
useEffect(() => { useEffect(() => {
if (!rootNavigationState?.key) { if (!rootNavigationState?.key) {
return; return;
} }
if (!pathname.startsWith("/auth") && !isLoggedIn) { if (!curServer && !pathname.startsWith("/server")) {
router.replace("/server");
return;
}
const isProtected = !["/auth", "/server"].find((path) =>
pathname.startsWith(path)
);
if (isProtected && !isLoggedIn) {
router.replace("/auth/login"); router.replace("/auth/login");
} else if (pathname.startsWith("/auth") && isLoggedIn) {
router.replace("/");
} }
}, [pathname, rootNavigationState, isLoggedIn]); }, [pathname, rootNavigationState, isLoggedIn]);

View File

@ -1,14 +1,3 @@
import { View, Text, Button } from "tamagui"; import LoginPage from "@/pages/auth/login";
import React from "react";
import authStore from "@/stores/auth";
export default function LoginPage() { export default LoginPage;
return (
<View>
<Text>LoginPage</Text>
<Button onPress={() => authStore.setState({ token: "123" })}>
Login
</Button>
</View>
);
}

View File

@ -1,9 +1,15 @@
import React from "react"; import React from "react";
import { Redirect } from "expo-router"; import { Redirect } from "expo-router";
import { useTermSession } from "@/stores/terminal-sessions"; import { useTermSession } from "@/stores/terminal-sessions";
import { useAppStore } from "@/stores/app";
export default function index() { export default function index() {
const { sessions, curSession } = useTermSession(); const { sessions, curSession } = useTermSession();
const { servers, curServer } = useAppStore();
if (!servers.length || !curServer) {
return <Redirect href="/server" />;
}
return ( return (
<Redirect <Redirect

View File

@ -0,0 +1,3 @@
import ServerPage from "@/pages/server/page";
export default ServerPage;

View File

@ -2,6 +2,7 @@ import React from "react";
import Terminal from "./terminal"; import Terminal from "./terminal";
import { BASE_WS_URL } from "@/lib/api"; import { BASE_WS_URL } from "@/lib/api";
import VNCViewer from "./vncviewer"; import VNCViewer from "./vncviewer";
import { useAuthStore } from "@/stores/auth";
type SSHSessionProps = { type SSHSessionProps = {
type: "ssh"; type: "ssh";
@ -28,7 +29,8 @@ export type InteractiveSessionProps = {
} & (SSHSessionProps | PVESessionProps | IncusSessionProps); } & (SSHSessionProps | PVESessionProps | IncusSessionProps);
const InteractiveSession = ({ type, params }: InteractiveSessionProps) => { const InteractiveSession = ({ type, params }: InteractiveSessionProps) => {
const query = new URLSearchParams(params); const { token } = useAuthStore();
const query = new URLSearchParams({ ...params, sid: token || "" });
const url = `${BASE_WS_URL}/ws/term?${query}`; const url = `${BASE_WS_URL}/ws/term?${query}`;
switch (type) { switch (type) {

View File

@ -0,0 +1,29 @@
import React from "react";
import { Button, GetProps } from "tamagui";
import Icons from "../ui/icons";
import useThemeStore from "@/stores/theme";
type Props = GetProps<typeof Button> & {
iconSize?: number;
};
const ThemeSwitcher = ({ iconSize = 24, ...props }: Props) => {
const { theme, toggle } = useThemeStore();
return (
<Button
icon={
<Icons
name={
theme === "light" ? "white-balance-sunny" : "moon-waning-crescent"
}
size={iconSize}
/>
}
onPress={toggle}
{...props}
/>
);
};
export default ThemeSwitcher;

View File

@ -4,11 +4,21 @@ import { Label, Text, View, XStack } from "tamagui";
type FormFieldProps = ComponentPropsWithoutRef<typeof XStack> & { type FormFieldProps = ComponentPropsWithoutRef<typeof XStack> & {
label?: string; label?: string;
htmlFor?: string; htmlFor?: string;
vertical?: boolean;
}; };
const FormField = ({ label, htmlFor, ...props }: FormFieldProps) => { const FormField = ({
label,
htmlFor,
vertical = false,
...props
}: FormFieldProps) => {
return ( return (
<XStack alignItems="flex-start" {...props}> <XStack
flexDirection={vertical ? "column" : "row"}
alignItems={vertical ? "stretch" : "flex-start"}
{...props}
>
<Label htmlFor={htmlFor} w={120} $xs={{ w: 100 }}> <Label htmlFor={htmlFor} w={120} $xs={{ w: 100 }}>
{label} {label}
</Label> </Label>

View File

@ -1,3 +1,4 @@
import authStore from "@/stores/auth";
import { QueryClient } from "@tanstack/react-query"; import { QueryClient } from "@tanstack/react-query";
import { ofetch } from "ofetch"; import { ofetch } from "ofetch";
@ -6,7 +7,18 @@ export const BASE_WS_URL = BASE_API_URL.replace("http", "ws");
const api = ofetch.create({ const api = ofetch.create({
baseURL: BASE_API_URL, baseURL: BASE_API_URL,
onRequest: (config) => {
const authToken = authStore.getState().token;
if (authToken) {
config.options.headers.set("Authorization", `Bearer ${authToken}`);
}
},
onResponseError: (error) => { onResponseError: (error) => {
if (error.response.status === 401 && !!authStore.getState().token) {
authStore.setState({ token: null });
throw new Error("Unauthorized");
}
if (error.response._data) { if (error.response._data) {
const message = error.response._data.message; const message = error.response._data.message;
throw new Error(message || "Something went wrong"); throw new Error(message || "Something went wrong");

View File

@ -0,0 +1,92 @@
import { Text, ScrollView, Card, Separator } from "tamagui";
import React from "react";
import FormField from "@/components/ui/form";
import { InputField } from "@/components/ui/input";
import { useZForm } from "@/hooks/useZForm";
import { router, Stack } from "expo-router";
import Button from "@/components/ui/button";
import ThemeSwitcher from "@/components/containers/theme-switcher";
import { useMutation } from "@tanstack/react-query";
import { z } from "zod";
import { ErrorAlert } from "@/components/ui/alert";
import { loginResultSchema, loginSchema } from "./schema";
import api from "@/lib/api";
import Icons from "@/components/ui/icons";
import authStore from "@/stores/auth";
export default function LoginPage() {
const form = useZForm(loginSchema);
const login = useMutation({
mutationFn: async (body: z.infer<typeof loginSchema>) => {
const res = await api("/auth/login", { method: "POST", body });
const { data } = loginResultSchema.safeParse(res);
if (!data) {
throw new Error("Invalid response!");
}
return data;
},
onSuccess(data) {
authStore.setState({ token: data.sessionId });
router.replace("/");
},
});
const onSubmit = form.handleSubmit((values) => {
login.mutate(values);
});
return (
<>
<Stack.Screen
options={{
contentStyle: {
width: "100%",
maxWidth: 600,
marginHorizontal: "auto",
},
title: "Login",
headerRight: () => (
<ThemeSwitcher bg="$colorTransparent" $gtSm={{ mr: "$3" }} />
),
}}
/>
<ScrollView
contentContainerStyle={{
padding: "$4",
pb: "$12",
justifyContent: "center",
flexGrow: 1,
}}
>
<Card bordered p="$4" gap="$4">
<Text fontSize="$8">Login</Text>
<ErrorAlert error={login.error} />
<FormField vertical label="Username/Email">
<InputField form={form} name="username" />
</FormField>
<FormField vertical label="Password">
<InputField form={form} name="password" secureTextEntry />
</FormField>
<Separator />
<Button
icon={<Icons name="lock" size={16} />}
onPress={onSubmit}
isLoading={login.isPending}
>
Connect
</Button>
<Button onPress={() => router.push("/server")} bg="$colorTransparent">
Change Server
</Button>
</Card>
</ScrollView>
</>
);
}

View File

@ -0,0 +1,10 @@
import { z } from "zod";
export const loginSchema = z.object({
username: z.string(),
password: z.string(),
});
export const loginResultSchema = z.object({
sessionId: z.string().min(40),
});

View File

@ -0,0 +1,78 @@
import { View, Text, ScrollView, Card } from "tamagui";
import React from "react";
import FormField from "@/components/ui/form";
import { InputField } from "@/components/ui/input";
import { useZForm } from "@/hooks/useZForm";
import { getServerResultSchema, serverSchema } from "./schema";
import { router, Stack } from "expo-router";
import Button from "@/components/ui/button";
import ThemeSwitcher from "@/components/containers/theme-switcher";
import { useMutation } from "@tanstack/react-query";
import { ofetch } from "ofetch";
import { z } from "zod";
import { ErrorAlert } from "@/components/ui/alert";
import { addServer } from "@/stores/app";
export default function ServerPage() {
const form = useZForm(serverSchema);
const serverConnect = useMutation({
mutationFn: async (body: z.infer<typeof serverSchema>) => {
const res = await ofetch(body.url + "/server");
const { data } = getServerResultSchema.safeParse(res);
if (!data) {
throw new Error("Invalid server");
}
return data;
},
onSuccess(data, payload) {
addServer({ url: payload.url, name: data.name }, true);
router.replace("/auth/login");
},
});
const onSubmit = form.handleSubmit((values) => {
serverConnect.mutate(values);
});
return (
<>
<Stack.Screen
options={{
contentStyle: {
width: "100%",
maxWidth: 600,
marginHorizontal: "auto",
},
title: "Vaulterm",
headerRight: () => (
<ThemeSwitcher bg="$colorTransparent" $gtSm={{ mr: "$3" }} />
),
}}
/>
<ScrollView
contentContainerStyle={{
padding: "$4",
pb: "$12",
justifyContent: "center",
flexGrow: 1,
}}
>
<Card bordered p="$4" gap="$4">
<Text fontSize="$8">Connect to Server</Text>
<ErrorAlert error={serverConnect.error} />
<FormField vertical label="URL">
<InputField form={form} name="url" placeholder="https://" />
</FormField>
<Button onPress={onSubmit} isLoading={serverConnect.isPending}>
Connect
</Button>
</Card>
</ScrollView>
</>
);
}

View File

@ -0,0 +1,10 @@
import { z } from "zod";
export const serverSchema = z.object({
url: z.string().url("Invalid URL"),
});
export const getServerResultSchema = z.object({
name: z.string(),
version: z.string().min(1),
});

62
frontend/stores/app.ts Normal file
View File

@ -0,0 +1,62 @@
import { createStore, useStore } from "zustand";
import { persist, createJSONStorage } from "zustand/middleware";
import AsyncStorage from "@react-native-async-storage/async-storage";
type AppServer = {
name?: string;
url: string;
};
type AppStore = {
servers: AppServer[];
curServerIdx?: number | null;
};
const appStore = createStore(
persist<AppStore>(
() => ({
servers: [],
curServerIdx: null,
}),
{
name: "vaulterm:app",
storage: createJSONStorage(() => AsyncStorage),
}
)
);
export function addServer(srv: AppServer, setActive?: boolean) {
const curServers = appStore.getState().servers;
const isExist = curServers.findIndex((s) => s.url === srv.url);
if (isExist >= 0) {
setActiveServer(isExist);
return;
}
appStore.setState((state) => ({
servers: [...state.servers, srv],
curServerIdx: setActive ? state.servers.length : state.curServerIdx,
}));
}
export function removeServer(idx: number) {
appStore.setState((state) => ({
servers: state.servers.filter((_, i) => i !== idx),
curServerIdx: state.curServerIdx === idx ? null : state.curServerIdx,
}));
}
export function setActiveServer(idx: number) {
appStore.setState({ curServerIdx: idx });
}
export const useAppStore = () => {
const state = useStore(appStore);
const curServer =
state.curServerIdx != null ? state.servers[state.curServerIdx] : null;
return { ...state, curServer };
};
export default appStore;

View File

@ -12,7 +12,7 @@ const authStore = createStore(
token: null, token: null,
}), }),
{ {
name: "auth", name: "vaulterm:auth",
storage: createJSONStorage(() => AsyncStorage), storage: createJSONStorage(() => AsyncStorage),
} }
) )

View File

@ -38,6 +38,9 @@ export const useTermSession = create(
set({ curSession: idx }); set({ curSession: idx });
}, },
}), }),
{ name: "term-sessions", storage: createJSONStorage(() => AsyncStorage) } {
name: "vaulterm:term-sessions",
storage: createJSONStorage(() => AsyncStorage),
}
) )
); );

View File

@ -20,7 +20,7 @@ const useThemeStore = create(
}, },
}), }),
{ {
name: "theme", name: "vaulterm:theme",
storage: createJSONStorage(() => AsyncStorage), storage: createJSONStorage(() => AsyncStorage),
} }
) )

View File

@ -5,10 +5,8 @@ import (
"github.com/gofiber/fiber/v2/middleware/cors" "github.com/gofiber/fiber/v2/middleware/cors"
"github.com/joho/godotenv" "github.com/joho/godotenv"
"rul.sh/vaulterm/app/auth" "rul.sh/vaulterm/app/auth"
"rul.sh/vaulterm/app/hosts"
"rul.sh/vaulterm/app/keychains"
"rul.sh/vaulterm/app/ws"
"rul.sh/vaulterm/db" "rul.sh/vaulterm/db"
"rul.sh/vaulterm/middleware"
) )
func NewApp() *fiber.App { func NewApp() *fiber.App {
@ -18,20 +16,26 @@ func NewApp() *fiber.App {
// Create fiber app // Create fiber app
app := fiber.New(fiber.Config{ErrorHandler: ErrorHandler}) app := fiber.New(fiber.Config{ErrorHandler: ErrorHandler})
// Middlewares
app.Use(cors.New()) app.Use(cors.New())
// Init app routes // Server info
auth.Router(app) app.Get("/server", func(c *fiber.Ctx) error {
hosts.Router(app) return c.JSON(&fiber.Map{
keychains.Router(app) "name": "Vaulterm",
ws.Router(app) "version": "0.0.1",
})
})
// Health check // Health check
app.Get("/health-check", func(c *fiber.Ctx) error { app.Get("/health-check", func(c *fiber.Ctx) error {
return c.SendString("OK") return c.SendString("OK")
}) })
app.Use(middleware.Auth)
auth.Router(app)
app.Use(middleware.Protected())
InitRouter(app)
return app return app
} }

View File

@ -9,7 +9,7 @@ import (
type Auth struct{ db *gorm.DB } type Auth struct{ db *gorm.DB }
func NewAuthRepository() *Auth { func NewRepository() *Auth {
return &Auth{db: db.Get()} return &Auth{db: db.Get()}
} }

View File

@ -1,10 +1,9 @@
package auth package auth
import ( import (
"strings"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"rul.sh/vaulterm/lib" "rul.sh/vaulterm/lib"
"rul.sh/vaulterm/middleware"
"rul.sh/vaulterm/utils" "rul.sh/vaulterm/utils"
) )
@ -12,12 +11,12 @@ func Router(app *fiber.App) {
router := app.Group("/auth") router := app.Group("/auth")
router.Post("/login", login) router.Post("/login", login)
router.Get("/user", getUser) router.Get("/user", middleware.Protected(), getUser)
router.Post("/logout", logout) router.Post("/logout", middleware.Protected(), logout)
} }
func login(c *fiber.Ctx) error { func login(c *fiber.Ctx) error {
repo := NewAuthRepository() repo := NewRepository()
var body LoginSchema var body LoginSchema
if err := c.BodyParser(&body); err != nil { if err := c.BodyParser(&body); err != nil {
@ -54,32 +53,15 @@ func login(c *fiber.Ctx) error {
} }
func getUser(c *fiber.Ctx) error { func getUser(c *fiber.Ctx) error {
auth := c.Get("Authorization") user := utils.GetUser(c)
var sessionId string return c.JSON(user)
if auth != "" {
sessionId = strings.Split(auth, " ")[1]
}
repo := NewAuthRepository()
session, err := repo.GetSession(sessionId)
if err != nil {
return utils.ResponseError(c, err, 500)
}
return c.JSON(session)
} }
func logout(c *fiber.Ctx) error { func logout(c *fiber.Ctx) error {
auth := c.Get("Authorization")
force := c.Query("force") force := c.Query("force")
var sessionId string sessionId := c.Locals("sessionId").(string)
if auth != "" { repo := NewRepository()
sessionId = strings.Split(auth, " ")[1]
}
repo := NewAuthRepository()
err := repo.RemoveUserSession(sessionId, force == "true") err := repo.RemoveUserSession(sessionId, force == "true")
if err != nil { if err != nil {

View File

@ -4,25 +4,36 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
"rul.sh/vaulterm/db" "rul.sh/vaulterm/db"
"rul.sh/vaulterm/models" "rul.sh/vaulterm/models"
"rul.sh/vaulterm/utils"
) )
type Hosts struct{ db *gorm.DB } type Hosts struct {
db *gorm.DB
User *utils.UserContext
}
func NewRepository() *Hosts { func NewRepository(r *Hosts) *Hosts {
return &Hosts{db: db.Get()} if r == nil {
r = &Hosts{}
}
r.db = db.Get()
return r
} }
func (r *Hosts) GetAll() ([]*models.Host, error) { func (r *Hosts) GetAll() ([]*models.Host, error) {
query := r.ACL(r.db.Order("id DESC"))
var rows []*models.Host var rows []*models.Host
ret := r.db.Order("id DESC").Find(&rows) ret := query.Find(&rows)
return rows, ret.Error return rows, ret.Error
} }
func (r *Hosts) Get(id string) (*models.HostDecrypted, error) { func (r *Hosts) Get(id string) (*models.HostDecrypted, error) {
var host models.Host query := r.ACL(r.db)
ret := r.db.Joins("Key").Joins("AltKey").Where("hosts.id = ?", id).First(&host)
var host models.Host
ret := query.Joins("Key").Joins("AltKey").Where("hosts.id = ?", id).First(&host)
if ret.Error != nil { if ret.Error != nil {
return nil, ret.Error return nil, ret.Error
} }
@ -37,12 +48,13 @@ func (r *Hosts) Get(id string) (*models.HostDecrypted, error) {
func (r *Hosts) Exists(id string) (bool, error) { func (r *Hosts) Exists(id string) (bool, error) {
var count int64 var count int64
ret := r.db.Model(&models.Host{}).Where("id = ?", id).Count(&count) ret := r.ACL(r.db.Model(&models.Host{}).Where("id = ?", id)).Count(&count)
return count > 0, ret.Error return count > 0, ret.Error
} }
func (r *Hosts) Delete(id string) error { func (r *Hosts) Delete(id string) error {
return r.db.Delete(&models.Host{Model: models.Model{ID: id}}).Error query := r.ACL(r.db)
return query.Delete(&models.Host{Model: models.Model{ID: id}}).Error
} }
func (r *Hosts) Create(item *models.Host) error { func (r *Hosts) Create(item *models.Host) error {
@ -50,5 +62,15 @@ func (r *Hosts) Create(item *models.Host) error {
} }
func (r *Hosts) Update(id string, item *models.Host) error { func (r *Hosts) Update(id string, item *models.Host) error {
return r.db.Where("id = ?", id).Updates(item).Error query := r.ACL(r.db.Where("id = ?", id))
return query.Updates(item).Error
}
func (r *Hosts) ACL(query *gorm.DB) *gorm.DB {
if r.User.IsAdmin {
return query
}
return query.Where("hosts.owner_id = ?", r.User.ID)
} }

View File

@ -9,7 +9,7 @@ import (
"rul.sh/vaulterm/utils" "rul.sh/vaulterm/utils"
) )
func Router(app *fiber.App) { func Router(app fiber.Router) {
router := app.Group("/hosts") router := app.Group("/hosts")
router.Get("/", getAll) router.Get("/", getAll)
@ -19,7 +19,9 @@ func Router(app *fiber.App) {
} }
func getAll(c *fiber.Ctx) error { func getAll(c *fiber.Ctx) error {
repo := NewRepository() user := utils.GetUser(c)
repo := NewRepository(&Hosts{User: user})
rows, err := repo.GetAll() rows, err := repo.GetAll()
if err != nil { if err != nil {
return utils.ResponseError(c, err, 500) return utils.ResponseError(c, err, 500)
@ -36,8 +38,11 @@ func create(c *fiber.Ctx) error {
return utils.ResponseError(c, err, 500) return utils.ResponseError(c, err, 500)
} }
repo := NewRepository() user := utils.GetUser(c)
repo := NewRepository(&Hosts{User: user})
item := &models.Host{ item := &models.Host{
OwnerID: user.ID,
Type: body.Type, Type: body.Type,
Label: body.Label, Label: body.Label,
Host: body.Host, Host: body.Host,
@ -48,7 +53,7 @@ func create(c *fiber.Ctx) error {
AltKeyID: body.AltKeyID, AltKeyID: body.AltKeyID,
} }
osName, err := tryConnect(item) osName, err := tryConnect(c, item)
if err != nil { if err != nil {
return utils.ResponseError(c, fmt.Errorf("cannot connect to the host: %s", err), 500) return utils.ResponseError(c, fmt.Errorf("cannot connect to the host: %s", err), 500)
} }
@ -67,7 +72,8 @@ func update(c *fiber.Ctx) error {
return utils.ResponseError(c, err, 500) return utils.ResponseError(c, err, 500)
} }
repo := NewRepository() user := utils.GetUser(c)
repo := NewRepository(&Hosts{User: user})
id := c.Params("id") id := c.Params("id")
exist, _ := repo.Exists(id) exist, _ := repo.Exists(id)
@ -87,7 +93,7 @@ func update(c *fiber.Ctx) error {
AltKeyID: body.AltKeyID, AltKeyID: body.AltKeyID,
} }
osName, err := tryConnect(item) osName, err := tryConnect(c, item)
if err != nil { if err != nil {
return utils.ResponseError(c, fmt.Errorf("cannot connect to the host: %s", err), 500) return utils.ResponseError(c, fmt.Errorf("cannot connect to the host: %s", err), 500)
} }
@ -101,7 +107,8 @@ func update(c *fiber.Ctx) error {
} }
func delete(c *fiber.Ctx) error { func delete(c *fiber.Ctx) error {
repo := NewRepository() user := utils.GetUser(c)
repo := NewRepository(&Hosts{User: user})
id := c.Params("id") id := c.Params("id")
exist, _ := repo.Exists(id) exist, _ := repo.Exists(id)

View File

@ -3,13 +3,16 @@ package hosts
import ( import (
"fmt" "fmt"
"github.com/gofiber/fiber/v2"
"rul.sh/vaulterm/app/keychains" "rul.sh/vaulterm/app/keychains"
"rul.sh/vaulterm/lib" "rul.sh/vaulterm/lib"
"rul.sh/vaulterm/models" "rul.sh/vaulterm/models"
"rul.sh/vaulterm/utils"
) )
func tryConnect(host *models.Host) (string, error) { func tryConnect(c *fiber.Ctx, host *models.Host) (string, error) {
keyRepo := keychains.NewRepository() user := utils.GetUser(c)
keyRepo := keychains.NewRepository(&keychains.Keychains{User: user})
var key map[string]interface{} var key map[string]interface{}
var altKey map[string]interface{} var altKey map[string]interface{}

View File

@ -4,18 +4,27 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
"rul.sh/vaulterm/db" "rul.sh/vaulterm/db"
"rul.sh/vaulterm/models" "rul.sh/vaulterm/models"
"rul.sh/vaulterm/utils"
) )
type Keychains struct{ db *gorm.DB } type Keychains struct {
db *gorm.DB
User *utils.UserContext
}
func NewRepository() *Keychains { func NewRepository(r *Keychains) *Keychains {
return &Keychains{db: db.Get()} if r == nil {
r = &Keychains{}
}
r.db = db.Get()
return r
} }
func (r *Keychains) GetAll() ([]*models.Keychain, error) { func (r *Keychains) GetAll() ([]*models.Keychain, error) {
var rows []*models.Keychain var rows []*models.Keychain
ret := r.db.Order("created_at DESC").Find(&rows) query := r.ACL(r.db.Order("created_at DESC"))
ret := query.Find(&rows)
return rows, ret.Error return rows, ret.Error
} }
@ -25,7 +34,9 @@ func (r *Keychains) Create(item *models.Keychain) error {
func (r *Keychains) Get(id string) (*models.Keychain, error) { func (r *Keychains) Get(id string) (*models.Keychain, error) {
var keychain models.Keychain var keychain models.Keychain
if err := r.db.Where("id = ?", id).First(&keychain).Error; err != nil { query := r.ACL(r.db.Where("id = ?", id))
if err := query.First(&keychain).Error; err != nil {
return nil, err return nil, err
} }
@ -34,7 +45,8 @@ func (r *Keychains) Get(id string) (*models.Keychain, error) {
func (r *Keychains) Exists(id string) (bool, error) { func (r *Keychains) Exists(id string) (bool, error) {
var count int64 var count int64
ret := r.db.Model(&models.Keychain{}).Where("id = ?", id).Count(&count) query := r.ACL(r.db.Model(&models.Keychain{}).Where("id = ?", id))
ret := query.Count(&count)
return count > 0, ret.Error return count > 0, ret.Error
} }
@ -58,5 +70,14 @@ func (r *Keychains) GetDecrypted(id string) (*KeychainDecrypted, error) {
} }
func (r *Keychains) Update(id string, item *models.Keychain) error { func (r *Keychains) Update(id string, item *models.Keychain) error {
return r.db.Where("id = ?", id).Updates(item).Error query := r.ACL(r.db.Where("id = ?", id))
return query.Updates(item).Error
}
func (r *Keychains) ACL(query *gorm.DB) *gorm.DB {
if r.User.IsAdmin {
return query
}
return query.Where("keychains.owner_id = ?", r.User.ID)
} }

View File

@ -9,7 +9,7 @@ import (
"rul.sh/vaulterm/utils" "rul.sh/vaulterm/utils"
) )
func Router(app *fiber.App) { func Router(app fiber.Router) {
router := app.Group("/keychains") router := app.Group("/keychains")
router.Get("/", getAll) router.Get("/", getAll)
@ -25,7 +25,9 @@ type GetAllResult struct {
func getAll(c *fiber.Ctx) error { func getAll(c *fiber.Ctx) error {
withData := c.Query("withData") withData := c.Query("withData")
repo := NewRepository() user := utils.GetUser(c)
repo := NewRepository(&Keychains{User: user})
rows, err := repo.GetAll() rows, err := repo.GetAll()
if err != nil { if err != nil {
return utils.ResponseError(c, err, 500) return utils.ResponseError(c, err, 500)
@ -62,9 +64,11 @@ func create(c *fiber.Ctx) error {
return utils.ResponseError(c, err, 500) return utils.ResponseError(c, err, 500)
} }
repo := NewRepository() user := utils.GetUser(c)
repo := NewRepository(&Keychains{User: user})
item := &models.Keychain{ item := &models.Keychain{
OwnerID: user.ID,
Type: body.Type, Type: body.Type,
Label: body.Label, Label: body.Label,
} }
@ -86,7 +90,9 @@ func update(c *fiber.Ctx) error {
return utils.ResponseError(c, err, 500) return utils.ResponseError(c, err, 500)
} }
repo := NewRepository() user := utils.GetUser(c)
repo := NewRepository(&Keychains{User: user})
id := c.Params("id") id := c.Params("id")
exist, _ := repo.Exists(id) exist, _ := repo.Exists(id)

23
server/app/router.go Normal file
View File

@ -0,0 +1,23 @@
package app
import (
"github.com/gofiber/fiber/v2"
"rul.sh/vaulterm/app/hosts"
"rul.sh/vaulterm/app/keychains"
"rul.sh/vaulterm/app/ws"
)
func InitRouter(app *fiber.App) {
// App route list
routes := []Router{
hosts.Router,
keychains.Router,
ws.Router,
}
for _, route := range routes {
route(app)
}
}
type Router func(app fiber.Router)

View File

@ -5,7 +5,7 @@ import (
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
) )
func Router(app *fiber.App) { func Router(app fiber.Router) {
router := app.Group("/ws") router := app.Group("/ws")
router.Use(func(c *fiber.Ctx) error { router.Use(func(c *fiber.Ctx) error {

View File

@ -13,7 +13,8 @@ import (
func HandleTerm(c *websocket.Conn) { func HandleTerm(c *websocket.Conn) {
hostId := c.Query("hostId") hostId := c.Query("hostId")
hostRepo := hosts.NewRepository() user := utils.GetUserWs(c)
hostRepo := hosts.NewRepository(&hosts.Hosts{User: user})
data, err := hostRepo.Get(hostId) data, err := hostRepo.Get(hostId)
if data == nil { if data == nil {

51
server/middleware/auth.go Normal file
View File

@ -0,0 +1,51 @@
package middleware
import (
"strings"
"github.com/gofiber/fiber/v2"
"rul.sh/vaulterm/db"
"rul.sh/vaulterm/models"
)
func Auth(c *fiber.Ctx) error {
authHeader := c.Get("Authorization")
var sessionId string
if authHeader != "" {
sessionId = strings.Split(authHeader, " ")[1]
}
if strings.HasPrefix(c.Path(), "/ws") && sessionId == "" {
sessionId = c.Query("sid")
}
session, _ := GetUserSession(sessionId)
if session != nil && session.User.ID != "" {
c.Locals("user", &session.User)
c.Locals("sessionId", sessionId)
}
return c.Next()
}
func GetUserSession(sessionId string) (*models.UserSession, error) {
var session models.UserSession
res := db.Get().Joins("User").Where("user_sessions.id = ?", sessionId).First(&session)
return &session, res.Error
}
func Protected() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
user := c.Locals("user")
if user == nil {
return &fiber.Error{
Code: fiber.StatusUnauthorized,
Message: "Unauthorized",
}
}
return c.Next()
}
}

View File

@ -14,6 +14,9 @@ const (
type Host struct { type Host struct {
Model Model
OwnerID string `json:"userId" gorm:"index:hosts_owner_id_idx;type:varchar(26)"`
Owner User `json:"user" gorm:"foreignKey:OwnerID"`
Type string `json:"type" gorm:"not null;index:hosts_type_idx;type:varchar(16)"` Type string `json:"type" gorm:"not null;index:hosts_type_idx;type:varchar(16)"`
Label string `json:"label"` Label string `json:"label"`
Host string `json:"host" gorm:"type:varchar(64)"` Host string `json:"host" gorm:"type:varchar(64)"`
@ -54,3 +57,14 @@ func (h *Host) DecryptKeys() (*HostDecrypted, error) {
return res, nil return res, nil
} }
type HostHasAccessOptions struct {
UserID string
}
func (h *Host) HasAccess(o HostHasAccessOptions) bool {
if o.UserID == h.OwnerID {
return true
}
return false
}

View File

@ -16,6 +16,9 @@ const (
type Keychain struct { type Keychain struct {
Model Model
OwnerID string `json:"userId" gorm:"index:hosts_owner_id_idx;type:varchar(26)"`
Owner User `json:"user" gorm:"foreignKey:OwnerID"`
Label string `json:"label"` Label string `json:"label"`
Type string `json:"type" gorm:"not null;index:keychains_type_idx;type:varchar(12)"` Type string `json:"type" gorm:"not null;index:keychains_type_idx;type:varchar(12)"`
Data string `json:"-" gorm:"type:text"` Data string `json:"-" gorm:"type:text"`

35
server/utils/context.go Normal file
View File

@ -0,0 +1,35 @@
package utils
import (
"github.com/gofiber/contrib/websocket"
"github.com/gofiber/fiber/v2"
"rul.sh/vaulterm/models"
)
type UserContext struct {
*models.User
IsAdmin bool
}
func getUserData(user *models.User) *UserContext {
isAdmin := false
if user.Role == models.UserRoleAdmin {
isAdmin = true
}
return &UserContext{
User: user,
IsAdmin: isAdmin,
}
}
func GetUser(c *fiber.Ctx) *UserContext {
user := c.Locals("user").(*models.User)
return getUserData(user)
}
func GetUserWs(c *websocket.Conn) *UserContext {
user := c.Locals("user").(*models.User)
return getUserData(user)
}