diff --git a/frontend/app/_providers.tsx b/frontend/app/_providers.tsx
index 15dc0d3..56da1e1 100644
--- a/frontend/app/_providers.tsx
+++ b/frontend/app/_providers.tsx
@@ -12,6 +12,7 @@ import { router, usePathname, useRootNavigationState } from "expo-router";
import { useAuthStore } from "@/stores/auth";
import { PortalProvider } from "tamagui";
import { queryClient } from "@/lib/api";
+import { useAppStore } from "@/stores/app";
type Props = PropsWithChildren;
@@ -53,16 +54,24 @@ const AuthProvider = () => {
const pathname = usePathname();
const rootNavigationState = useRootNavigationState();
const { isLoggedIn } = useAuthStore();
+ const { curServer } = useAppStore();
useEffect(() => {
if (!rootNavigationState?.key) {
return;
}
- if (!pathname.startsWith("/auth") && !isLoggedIn) {
+ if (!curServer && !pathname.startsWith("/server")) {
+ router.replace("/server");
+ return;
+ }
+
+ const isProtected = !["/auth", "/server"].find((path) =>
+ pathname.startsWith(path)
+ );
+
+ if (isProtected && !isLoggedIn) {
router.replace("/auth/login");
- } else if (pathname.startsWith("/auth") && isLoggedIn) {
- router.replace("/");
}
}, [pathname, rootNavigationState, isLoggedIn]);
diff --git a/frontend/app/auth/login.tsx b/frontend/app/auth/login.tsx
index 0a0da13..7c556ca 100644
--- a/frontend/app/auth/login.tsx
+++ b/frontend/app/auth/login.tsx
@@ -1,14 +1,3 @@
-import { View, Text, Button } from "tamagui";
-import React from "react";
-import authStore from "@/stores/auth";
+import LoginPage from "@/pages/auth/login";
-export default function LoginPage() {
- return (
-
- LoginPage
-
-
- );
-}
+export default LoginPage;
diff --git a/frontend/app/index.tsx b/frontend/app/index.tsx
index fac85db..0b7d8d9 100644
--- a/frontend/app/index.tsx
+++ b/frontend/app/index.tsx
@@ -1,9 +1,15 @@
import React from "react";
import { Redirect } from "expo-router";
import { useTermSession } from "@/stores/terminal-sessions";
+import { useAppStore } from "@/stores/app";
export default function index() {
const { sessions, curSession } = useTermSession();
+ const { servers, curServer } = useAppStore();
+
+ if (!servers.length || !curServer) {
+ return ;
+ }
return (
{
- const query = new URLSearchParams(params);
+ const { token } = useAuthStore();
+ const query = new URLSearchParams({ ...params, sid: token || "" });
const url = `${BASE_WS_URL}/ws/term?${query}`;
switch (type) {
diff --git a/frontend/components/containers/theme-switcher.tsx b/frontend/components/containers/theme-switcher.tsx
new file mode 100644
index 0000000..fbd5871
--- /dev/null
+++ b/frontend/components/containers/theme-switcher.tsx
@@ -0,0 +1,29 @@
+import React from "react";
+import { Button, GetProps } from "tamagui";
+import Icons from "../ui/icons";
+import useThemeStore from "@/stores/theme";
+
+type Props = GetProps & {
+ iconSize?: number;
+};
+
+const ThemeSwitcher = ({ iconSize = 24, ...props }: Props) => {
+ const { theme, toggle } = useThemeStore();
+
+ return (
+
+ }
+ onPress={toggle}
+ {...props}
+ />
+ );
+};
+
+export default ThemeSwitcher;
diff --git a/frontend/components/ui/form.tsx b/frontend/components/ui/form.tsx
index c436c64..905dde9 100644
--- a/frontend/components/ui/form.tsx
+++ b/frontend/components/ui/form.tsx
@@ -4,11 +4,21 @@ import { Label, Text, View, XStack } from "tamagui";
type FormFieldProps = ComponentPropsWithoutRef & {
label?: string;
htmlFor?: string;
+ vertical?: boolean;
};
-const FormField = ({ label, htmlFor, ...props }: FormFieldProps) => {
+const FormField = ({
+ label,
+ htmlFor,
+ vertical = false,
+ ...props
+}: FormFieldProps) => {
return (
-
+
diff --git a/frontend/lib/api.ts b/frontend/lib/api.ts
index 3f59476..e5975cf 100644
--- a/frontend/lib/api.ts
+++ b/frontend/lib/api.ts
@@ -1,3 +1,4 @@
+import authStore from "@/stores/auth";
import { QueryClient } from "@tanstack/react-query";
import { ofetch } from "ofetch";
@@ -6,7 +7,18 @@ export const BASE_WS_URL = BASE_API_URL.replace("http", "ws");
const api = ofetch.create({
baseURL: BASE_API_URL,
+ onRequest: (config) => {
+ const authToken = authStore.getState().token;
+ if (authToken) {
+ config.options.headers.set("Authorization", `Bearer ${authToken}`);
+ }
+ },
onResponseError: (error) => {
+ if (error.response.status === 401 && !!authStore.getState().token) {
+ authStore.setState({ token: null });
+ throw new Error("Unauthorized");
+ }
+
if (error.response._data) {
const message = error.response._data.message;
throw new Error(message || "Something went wrong");
diff --git a/frontend/pages/auth/login.tsx b/frontend/pages/auth/login.tsx
new file mode 100644
index 0000000..dbab297
--- /dev/null
+++ b/frontend/pages/auth/login.tsx
@@ -0,0 +1,92 @@
+import { Text, ScrollView, Card, Separator } from "tamagui";
+import React from "react";
+import FormField from "@/components/ui/form";
+import { InputField } from "@/components/ui/input";
+import { useZForm } from "@/hooks/useZForm";
+import { router, Stack } from "expo-router";
+import Button from "@/components/ui/button";
+import ThemeSwitcher from "@/components/containers/theme-switcher";
+import { useMutation } from "@tanstack/react-query";
+import { z } from "zod";
+import { ErrorAlert } from "@/components/ui/alert";
+import { loginResultSchema, loginSchema } from "./schema";
+import api from "@/lib/api";
+import Icons from "@/components/ui/icons";
+import authStore from "@/stores/auth";
+
+export default function LoginPage() {
+ const form = useZForm(loginSchema);
+
+ const login = useMutation({
+ mutationFn: async (body: z.infer) => {
+ const res = await api("/auth/login", { method: "POST", body });
+ const { data } = loginResultSchema.safeParse(res);
+ if (!data) {
+ throw new Error("Invalid response!");
+ }
+ return data;
+ },
+ onSuccess(data) {
+ authStore.setState({ token: data.sessionId });
+ router.replace("/");
+ },
+ });
+
+ const onSubmit = form.handleSubmit((values) => {
+ login.mutate(values);
+ });
+
+ return (
+ <>
+ (
+
+ ),
+ }}
+ />
+
+
+
+ Login
+
+
+
+
+
+
+
+
+
+
+
+
+ }
+ onPress={onSubmit}
+ isLoading={login.isPending}
+ >
+ Connect
+
+
+
+
+
+ >
+ );
+}
diff --git a/frontend/pages/auth/schema.ts b/frontend/pages/auth/schema.ts
new file mode 100644
index 0000000..9e5daf9
--- /dev/null
+++ b/frontend/pages/auth/schema.ts
@@ -0,0 +1,10 @@
+import { z } from "zod";
+
+export const loginSchema = z.object({
+ username: z.string(),
+ password: z.string(),
+});
+
+export const loginResultSchema = z.object({
+ sessionId: z.string().min(40),
+});
diff --git a/frontend/pages/server/page.tsx b/frontend/pages/server/page.tsx
new file mode 100644
index 0000000..0baa84b
--- /dev/null
+++ b/frontend/pages/server/page.tsx
@@ -0,0 +1,78 @@
+import { View, Text, ScrollView, Card } from "tamagui";
+import React from "react";
+import FormField from "@/components/ui/form";
+import { InputField } from "@/components/ui/input";
+import { useZForm } from "@/hooks/useZForm";
+import { getServerResultSchema, serverSchema } from "./schema";
+import { router, Stack } from "expo-router";
+import Button from "@/components/ui/button";
+import ThemeSwitcher from "@/components/containers/theme-switcher";
+import { useMutation } from "@tanstack/react-query";
+import { ofetch } from "ofetch";
+import { z } from "zod";
+import { ErrorAlert } from "@/components/ui/alert";
+import { addServer } from "@/stores/app";
+
+export default function ServerPage() {
+ const form = useZForm(serverSchema);
+
+ const serverConnect = useMutation({
+ mutationFn: async (body: z.infer) => {
+ const res = await ofetch(body.url + "/server");
+ const { data } = getServerResultSchema.safeParse(res);
+ if (!data) {
+ throw new Error("Invalid server");
+ }
+ return data;
+ },
+ onSuccess(data, payload) {
+ addServer({ url: payload.url, name: data.name }, true);
+ router.replace("/auth/login");
+ },
+ });
+
+ const onSubmit = form.handleSubmit((values) => {
+ serverConnect.mutate(values);
+ });
+
+ return (
+ <>
+ (
+
+ ),
+ }}
+ />
+
+
+
+ Connect to Server
+
+
+
+
+
+
+
+
+
+
+ >
+ );
+}
diff --git a/frontend/pages/server/schema.ts b/frontend/pages/server/schema.ts
new file mode 100644
index 0000000..b8adceb
--- /dev/null
+++ b/frontend/pages/server/schema.ts
@@ -0,0 +1,10 @@
+import { z } from "zod";
+
+export const serverSchema = z.object({
+ url: z.string().url("Invalid URL"),
+});
+
+export const getServerResultSchema = z.object({
+ name: z.string(),
+ version: z.string().min(1),
+});
diff --git a/frontend/stores/app.ts b/frontend/stores/app.ts
new file mode 100644
index 0000000..ad0fb09
--- /dev/null
+++ b/frontend/stores/app.ts
@@ -0,0 +1,62 @@
+import { createStore, useStore } from "zustand";
+import { persist, createJSONStorage } from "zustand/middleware";
+import AsyncStorage from "@react-native-async-storage/async-storage";
+
+type AppServer = {
+ name?: string;
+ url: string;
+};
+
+type AppStore = {
+ servers: AppServer[];
+ curServerIdx?: number | null;
+};
+
+const appStore = createStore(
+ persist(
+ () => ({
+ servers: [],
+ curServerIdx: null,
+ }),
+ {
+ name: "vaulterm:app",
+ storage: createJSONStorage(() => AsyncStorage),
+ }
+ )
+);
+
+export function addServer(srv: AppServer, setActive?: boolean) {
+ const curServers = appStore.getState().servers;
+ const isExist = curServers.findIndex((s) => s.url === srv.url);
+
+ if (isExist >= 0) {
+ setActiveServer(isExist);
+ return;
+ }
+
+ appStore.setState((state) => ({
+ servers: [...state.servers, srv],
+ curServerIdx: setActive ? state.servers.length : state.curServerIdx,
+ }));
+}
+
+export function removeServer(idx: number) {
+ appStore.setState((state) => ({
+ servers: state.servers.filter((_, i) => i !== idx),
+ curServerIdx: state.curServerIdx === idx ? null : state.curServerIdx,
+ }));
+}
+
+export function setActiveServer(idx: number) {
+ appStore.setState({ curServerIdx: idx });
+}
+
+export const useAppStore = () => {
+ const state = useStore(appStore);
+ const curServer =
+ state.curServerIdx != null ? state.servers[state.curServerIdx] : null;
+
+ return { ...state, curServer };
+};
+
+export default appStore;
diff --git a/frontend/stores/auth.ts b/frontend/stores/auth.ts
index d4facf4..9d812e5 100644
--- a/frontend/stores/auth.ts
+++ b/frontend/stores/auth.ts
@@ -12,7 +12,7 @@ const authStore = createStore(
token: null,
}),
{
- name: "auth",
+ name: "vaulterm:auth",
storage: createJSONStorage(() => AsyncStorage),
}
)
diff --git a/frontend/stores/terminal-sessions.ts b/frontend/stores/terminal-sessions.ts
index 992bb2f..55e47be 100644
--- a/frontend/stores/terminal-sessions.ts
+++ b/frontend/stores/terminal-sessions.ts
@@ -38,6 +38,9 @@ export const useTermSession = create(
set({ curSession: idx });
},
}),
- { name: "term-sessions", storage: createJSONStorage(() => AsyncStorage) }
+ {
+ name: "vaulterm:term-sessions",
+ storage: createJSONStorage(() => AsyncStorage),
+ }
)
);
diff --git a/frontend/stores/theme.ts b/frontend/stores/theme.ts
index 105fef7..9ae291f 100644
--- a/frontend/stores/theme.ts
+++ b/frontend/stores/theme.ts
@@ -20,7 +20,7 @@ const useThemeStore = create(
},
}),
{
- name: "theme",
+ name: "vaulterm:theme",
storage: createJSONStorage(() => AsyncStorage),
}
)
diff --git a/server/app/app.go b/server/app/app.go
index 498e194..3924f2d 100644
--- a/server/app/app.go
+++ b/server/app/app.go
@@ -5,10 +5,8 @@ import (
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/joho/godotenv"
"rul.sh/vaulterm/app/auth"
- "rul.sh/vaulterm/app/hosts"
- "rul.sh/vaulterm/app/keychains"
- "rul.sh/vaulterm/app/ws"
"rul.sh/vaulterm/db"
+ "rul.sh/vaulterm/middleware"
)
func NewApp() *fiber.App {
@@ -18,20 +16,26 @@ func NewApp() *fiber.App {
// Create fiber app
app := fiber.New(fiber.Config{ErrorHandler: ErrorHandler})
-
- // Middlewares
app.Use(cors.New())
- // Init app routes
- auth.Router(app)
- hosts.Router(app)
- keychains.Router(app)
- ws.Router(app)
+ // Server info
+ app.Get("/server", func(c *fiber.Ctx) error {
+ return c.JSON(&fiber.Map{
+ "name": "Vaulterm",
+ "version": "0.0.1",
+ })
+ })
// Health check
app.Get("/health-check", func(c *fiber.Ctx) error {
return c.SendString("OK")
})
+ app.Use(middleware.Auth)
+ auth.Router(app)
+
+ app.Use(middleware.Protected())
+ InitRouter(app)
+
return app
}
diff --git a/server/app/auth/repository.go b/server/app/auth/repository.go
index 4c3c119..df3cf52 100644
--- a/server/app/auth/repository.go
+++ b/server/app/auth/repository.go
@@ -9,7 +9,7 @@ import (
type Auth struct{ db *gorm.DB }
-func NewAuthRepository() *Auth {
+func NewRepository() *Auth {
return &Auth{db: db.Get()}
}
diff --git a/server/app/auth/router.go b/server/app/auth/router.go
index 823d4f9..fbb6352 100644
--- a/server/app/auth/router.go
+++ b/server/app/auth/router.go
@@ -1,10 +1,9 @@
package auth
import (
- "strings"
-
"github.com/gofiber/fiber/v2"
"rul.sh/vaulterm/lib"
+ "rul.sh/vaulterm/middleware"
"rul.sh/vaulterm/utils"
)
@@ -12,12 +11,12 @@ func Router(app *fiber.App) {
router := app.Group("/auth")
router.Post("/login", login)
- router.Get("/user", getUser)
- router.Post("/logout", logout)
+ router.Get("/user", middleware.Protected(), getUser)
+ router.Post("/logout", middleware.Protected(), logout)
}
func login(c *fiber.Ctx) error {
- repo := NewAuthRepository()
+ repo := NewRepository()
var body LoginSchema
if err := c.BodyParser(&body); err != nil {
@@ -54,32 +53,15 @@ func login(c *fiber.Ctx) error {
}
func getUser(c *fiber.Ctx) error {
- auth := c.Get("Authorization")
- var sessionId string
-
- if auth != "" {
- sessionId = strings.Split(auth, " ")[1]
- }
-
- repo := NewAuthRepository()
- session, err := repo.GetSession(sessionId)
- if err != nil {
- return utils.ResponseError(c, err, 500)
- }
-
- return c.JSON(session)
+ user := utils.GetUser(c)
+ return c.JSON(user)
}
func logout(c *fiber.Ctx) error {
- auth := c.Get("Authorization")
force := c.Query("force")
- var sessionId string
+ sessionId := c.Locals("sessionId").(string)
- if auth != "" {
- sessionId = strings.Split(auth, " ")[1]
- }
-
- repo := NewAuthRepository()
+ repo := NewRepository()
err := repo.RemoveUserSession(sessionId, force == "true")
if err != nil {
diff --git a/server/app/hosts/repository.go b/server/app/hosts/repository.go
index 0b88aa5..6d97aae 100644
--- a/server/app/hosts/repository.go
+++ b/server/app/hosts/repository.go
@@ -4,25 +4,36 @@ import (
"gorm.io/gorm"
"rul.sh/vaulterm/db"
"rul.sh/vaulterm/models"
+ "rul.sh/vaulterm/utils"
)
-type Hosts struct{ db *gorm.DB }
+type Hosts struct {
+ db *gorm.DB
+ User *utils.UserContext
+}
-func NewRepository() *Hosts {
- return &Hosts{db: db.Get()}
+func NewRepository(r *Hosts) *Hosts {
+ if r == nil {
+ r = &Hosts{}
+ }
+ r.db = db.Get()
+ return r
}
func (r *Hosts) GetAll() ([]*models.Host, error) {
+ query := r.ACL(r.db.Order("id DESC"))
+
var rows []*models.Host
- ret := r.db.Order("id DESC").Find(&rows)
+ ret := query.Find(&rows)
return rows, ret.Error
}
func (r *Hosts) Get(id string) (*models.HostDecrypted, error) {
- var host models.Host
- ret := r.db.Joins("Key").Joins("AltKey").Where("hosts.id = ?", id).First(&host)
+ query := r.ACL(r.db)
+ var host models.Host
+ ret := query.Joins("Key").Joins("AltKey").Where("hosts.id = ?", id).First(&host)
if ret.Error != nil {
return nil, ret.Error
}
@@ -37,12 +48,13 @@ func (r *Hosts) Get(id string) (*models.HostDecrypted, error) {
func (r *Hosts) Exists(id string) (bool, error) {
var count int64
- ret := r.db.Model(&models.Host{}).Where("id = ?", id).Count(&count)
+ ret := r.ACL(r.db.Model(&models.Host{}).Where("id = ?", id)).Count(&count)
return count > 0, ret.Error
}
func (r *Hosts) Delete(id string) error {
- return r.db.Delete(&models.Host{Model: models.Model{ID: id}}).Error
+ query := r.ACL(r.db)
+ return query.Delete(&models.Host{Model: models.Model{ID: id}}).Error
}
func (r *Hosts) Create(item *models.Host) error {
@@ -50,5 +62,15 @@ func (r *Hosts) Create(item *models.Host) error {
}
func (r *Hosts) Update(id string, item *models.Host) error {
- return r.db.Where("id = ?", id).Updates(item).Error
+ query := r.ACL(r.db.Where("id = ?", id))
+
+ return query.Updates(item).Error
+}
+
+func (r *Hosts) ACL(query *gorm.DB) *gorm.DB {
+ if r.User.IsAdmin {
+ return query
+ }
+
+ return query.Where("hosts.owner_id = ?", r.User.ID)
}
diff --git a/server/app/hosts/router.go b/server/app/hosts/router.go
index 91b9476..1481c0f 100644
--- a/server/app/hosts/router.go
+++ b/server/app/hosts/router.go
@@ -9,7 +9,7 @@ import (
"rul.sh/vaulterm/utils"
)
-func Router(app *fiber.App) {
+func Router(app fiber.Router) {
router := app.Group("/hosts")
router.Get("/", getAll)
@@ -19,7 +19,9 @@ func Router(app *fiber.App) {
}
func getAll(c *fiber.Ctx) error {
- repo := NewRepository()
+ user := utils.GetUser(c)
+ repo := NewRepository(&Hosts{User: user})
+
rows, err := repo.GetAll()
if err != nil {
return utils.ResponseError(c, err, 500)
@@ -36,8 +38,11 @@ func create(c *fiber.Ctx) error {
return utils.ResponseError(c, err, 500)
}
- repo := NewRepository()
+ user := utils.GetUser(c)
+ repo := NewRepository(&Hosts{User: user})
+
item := &models.Host{
+ OwnerID: user.ID,
Type: body.Type,
Label: body.Label,
Host: body.Host,
@@ -48,7 +53,7 @@ func create(c *fiber.Ctx) error {
AltKeyID: body.AltKeyID,
}
- osName, err := tryConnect(item)
+ osName, err := tryConnect(c, item)
if err != nil {
return utils.ResponseError(c, fmt.Errorf("cannot connect to the host: %s", err), 500)
}
@@ -67,7 +72,8 @@ func update(c *fiber.Ctx) error {
return utils.ResponseError(c, err, 500)
}
- repo := NewRepository()
+ user := utils.GetUser(c)
+ repo := NewRepository(&Hosts{User: user})
id := c.Params("id")
exist, _ := repo.Exists(id)
@@ -87,7 +93,7 @@ func update(c *fiber.Ctx) error {
AltKeyID: body.AltKeyID,
}
- osName, err := tryConnect(item)
+ osName, err := tryConnect(c, item)
if err != nil {
return utils.ResponseError(c, fmt.Errorf("cannot connect to the host: %s", err), 500)
}
@@ -101,7 +107,8 @@ func update(c *fiber.Ctx) error {
}
func delete(c *fiber.Ctx) error {
- repo := NewRepository()
+ user := utils.GetUser(c)
+ repo := NewRepository(&Hosts{User: user})
id := c.Params("id")
exist, _ := repo.Exists(id)
diff --git a/server/app/hosts/utils.go b/server/app/hosts/utils.go
index 837edab..d953d7a 100644
--- a/server/app/hosts/utils.go
+++ b/server/app/hosts/utils.go
@@ -3,13 +3,16 @@ package hosts
import (
"fmt"
+ "github.com/gofiber/fiber/v2"
"rul.sh/vaulterm/app/keychains"
"rul.sh/vaulterm/lib"
"rul.sh/vaulterm/models"
+ "rul.sh/vaulterm/utils"
)
-func tryConnect(host *models.Host) (string, error) {
- keyRepo := keychains.NewRepository()
+func tryConnect(c *fiber.Ctx, host *models.Host) (string, error) {
+ user := utils.GetUser(c)
+ keyRepo := keychains.NewRepository(&keychains.Keychains{User: user})
var key map[string]interface{}
var altKey map[string]interface{}
diff --git a/server/app/keychains/repository.go b/server/app/keychains/repository.go
index 69dccc8..5cfe4fc 100644
--- a/server/app/keychains/repository.go
+++ b/server/app/keychains/repository.go
@@ -4,18 +4,27 @@ import (
"gorm.io/gorm"
"rul.sh/vaulterm/db"
"rul.sh/vaulterm/models"
+ "rul.sh/vaulterm/utils"
)
-type Keychains struct{ db *gorm.DB }
+type Keychains struct {
+ db *gorm.DB
+ User *utils.UserContext
+}
-func NewRepository() *Keychains {
- return &Keychains{db: db.Get()}
+func NewRepository(r *Keychains) *Keychains {
+ if r == nil {
+ r = &Keychains{}
+ }
+ r.db = db.Get()
+ return r
}
func (r *Keychains) GetAll() ([]*models.Keychain, error) {
var rows []*models.Keychain
- ret := r.db.Order("created_at DESC").Find(&rows)
+ query := r.ACL(r.db.Order("created_at DESC"))
+ ret := query.Find(&rows)
return rows, ret.Error
}
@@ -25,7 +34,9 @@ func (r *Keychains) Create(item *models.Keychain) error {
func (r *Keychains) Get(id string) (*models.Keychain, error) {
var keychain models.Keychain
- if err := r.db.Where("id = ?", id).First(&keychain).Error; err != nil {
+ query := r.ACL(r.db.Where("id = ?", id))
+
+ if err := query.First(&keychain).Error; err != nil {
return nil, err
}
@@ -34,7 +45,8 @@ func (r *Keychains) Get(id string) (*models.Keychain, error) {
func (r *Keychains) Exists(id string) (bool, error) {
var count int64
- ret := r.db.Model(&models.Keychain{}).Where("id = ?", id).Count(&count)
+ query := r.ACL(r.db.Model(&models.Keychain{}).Where("id = ?", id))
+ ret := query.Count(&count)
return count > 0, ret.Error
}
@@ -58,5 +70,14 @@ func (r *Keychains) GetDecrypted(id string) (*KeychainDecrypted, error) {
}
func (r *Keychains) Update(id string, item *models.Keychain) error {
- return r.db.Where("id = ?", id).Updates(item).Error
+ query := r.ACL(r.db.Where("id = ?", id))
+ return query.Updates(item).Error
+}
+
+func (r *Keychains) ACL(query *gorm.DB) *gorm.DB {
+ if r.User.IsAdmin {
+ return query
+ }
+
+ return query.Where("keychains.owner_id = ?", r.User.ID)
}
diff --git a/server/app/keychains/router.go b/server/app/keychains/router.go
index 9f1e543..8e200fa 100644
--- a/server/app/keychains/router.go
+++ b/server/app/keychains/router.go
@@ -9,7 +9,7 @@ import (
"rul.sh/vaulterm/utils"
)
-func Router(app *fiber.App) {
+func Router(app fiber.Router) {
router := app.Group("/keychains")
router.Get("/", getAll)
@@ -25,7 +25,9 @@ type GetAllResult struct {
func getAll(c *fiber.Ctx) error {
withData := c.Query("withData")
- repo := NewRepository()
+ user := utils.GetUser(c)
+ repo := NewRepository(&Keychains{User: user})
+
rows, err := repo.GetAll()
if err != nil {
return utils.ResponseError(c, err, 500)
@@ -62,11 +64,13 @@ func create(c *fiber.Ctx) error {
return utils.ResponseError(c, err, 500)
}
- repo := NewRepository()
+ user := utils.GetUser(c)
+ repo := NewRepository(&Keychains{User: user})
item := &models.Keychain{
- Type: body.Type,
- Label: body.Label,
+ OwnerID: user.ID,
+ Type: body.Type,
+ Label: body.Label,
}
if err := item.EncryptData(body.Data); err != nil {
@@ -86,7 +90,9 @@ func update(c *fiber.Ctx) error {
return utils.ResponseError(c, err, 500)
}
- repo := NewRepository()
+ user := utils.GetUser(c)
+ repo := NewRepository(&Keychains{User: user})
+
id := c.Params("id")
exist, _ := repo.Exists(id)
diff --git a/server/app/router.go b/server/app/router.go
new file mode 100644
index 0000000..94751ec
--- /dev/null
+++ b/server/app/router.go
@@ -0,0 +1,23 @@
+package app
+
+import (
+ "github.com/gofiber/fiber/v2"
+ "rul.sh/vaulterm/app/hosts"
+ "rul.sh/vaulterm/app/keychains"
+ "rul.sh/vaulterm/app/ws"
+)
+
+func InitRouter(app *fiber.App) {
+ // App route list
+ routes := []Router{
+ hosts.Router,
+ keychains.Router,
+ ws.Router,
+ }
+
+ for _, route := range routes {
+ route(app)
+ }
+}
+
+type Router func(app fiber.Router)
diff --git a/server/app/ws/router.go b/server/app/ws/router.go
index 9891caf..d0753fa 100644
--- a/server/app/ws/router.go
+++ b/server/app/ws/router.go
@@ -5,7 +5,7 @@ import (
"github.com/gofiber/fiber/v2"
)
-func Router(app *fiber.App) {
+func Router(app fiber.Router) {
router := app.Group("/ws")
router.Use(func(c *fiber.Ctx) error {
diff --git a/server/app/ws/term.go b/server/app/ws/term.go
index e82e4f3..f62a005 100644
--- a/server/app/ws/term.go
+++ b/server/app/ws/term.go
@@ -13,7 +13,8 @@ import (
func HandleTerm(c *websocket.Conn) {
hostId := c.Query("hostId")
- hostRepo := hosts.NewRepository()
+ user := utils.GetUserWs(c)
+ hostRepo := hosts.NewRepository(&hosts.Hosts{User: user})
data, err := hostRepo.Get(hostId)
if data == nil {
diff --git a/server/middleware/auth.go b/server/middleware/auth.go
new file mode 100644
index 0000000..559c7e7
--- /dev/null
+++ b/server/middleware/auth.go
@@ -0,0 +1,51 @@
+package middleware
+
+import (
+ "strings"
+
+ "github.com/gofiber/fiber/v2"
+ "rul.sh/vaulterm/db"
+ "rul.sh/vaulterm/models"
+)
+
+func Auth(c *fiber.Ctx) error {
+ authHeader := c.Get("Authorization")
+ var sessionId string
+
+ if authHeader != "" {
+ sessionId = strings.Split(authHeader, " ")[1]
+ }
+
+ if strings.HasPrefix(c.Path(), "/ws") && sessionId == "" {
+ sessionId = c.Query("sid")
+ }
+
+ session, _ := GetUserSession(sessionId)
+
+ if session != nil && session.User.ID != "" {
+ c.Locals("user", &session.User)
+ c.Locals("sessionId", sessionId)
+ }
+
+ return c.Next()
+}
+
+func GetUserSession(sessionId string) (*models.UserSession, error) {
+ var session models.UserSession
+ res := db.Get().Joins("User").Where("user_sessions.id = ?", sessionId).First(&session)
+ return &session, res.Error
+}
+
+func Protected() func(c *fiber.Ctx) error {
+ return func(c *fiber.Ctx) error {
+ user := c.Locals("user")
+ if user == nil {
+ return &fiber.Error{
+ Code: fiber.StatusUnauthorized,
+ Message: "Unauthorized",
+ }
+ }
+ return c.Next()
+ }
+
+}
diff --git a/server/models/host.go b/server/models/host.go
index 404223b..9bdbb80 100644
--- a/server/models/host.go
+++ b/server/models/host.go
@@ -14,6 +14,9 @@ const (
type Host struct {
Model
+ OwnerID string `json:"userId" gorm:"index:hosts_owner_id_idx;type:varchar(26)"`
+ Owner User `json:"user" gorm:"foreignKey:OwnerID"`
+
Type string `json:"type" gorm:"not null;index:hosts_type_idx;type:varchar(16)"`
Label string `json:"label"`
Host string `json:"host" gorm:"type:varchar(64)"`
@@ -54,3 +57,14 @@ func (h *Host) DecryptKeys() (*HostDecrypted, error) {
return res, nil
}
+
+type HostHasAccessOptions struct {
+ UserID string
+}
+
+func (h *Host) HasAccess(o HostHasAccessOptions) bool {
+ if o.UserID == h.OwnerID {
+ return true
+ }
+ return false
+}
diff --git a/server/models/keychain.go b/server/models/keychain.go
index b002e0d..9ae6f17 100644
--- a/server/models/keychain.go
+++ b/server/models/keychain.go
@@ -16,6 +16,9 @@ const (
type Keychain struct {
Model
+ OwnerID string `json:"userId" gorm:"index:hosts_owner_id_idx;type:varchar(26)"`
+ Owner User `json:"user" gorm:"foreignKey:OwnerID"`
+
Label string `json:"label"`
Type string `json:"type" gorm:"not null;index:keychains_type_idx;type:varchar(12)"`
Data string `json:"-" gorm:"type:text"`
diff --git a/server/utils/context.go b/server/utils/context.go
new file mode 100644
index 0000000..e58d5a8
--- /dev/null
+++ b/server/utils/context.go
@@ -0,0 +1,35 @@
+package utils
+
+import (
+ "github.com/gofiber/contrib/websocket"
+ "github.com/gofiber/fiber/v2"
+ "rul.sh/vaulterm/models"
+)
+
+type UserContext struct {
+ *models.User
+ IsAdmin bool
+}
+
+func getUserData(user *models.User) *UserContext {
+ isAdmin := false
+
+ if user.Role == models.UserRoleAdmin {
+ isAdmin = true
+ }
+
+ return &UserContext{
+ User: user,
+ IsAdmin: isAdmin,
+ }
+}
+
+func GetUser(c *fiber.Ctx) *UserContext {
+ user := c.Locals("user").(*models.User)
+ return getUserData(user)
+}
+
+func GetUserWs(c *websocket.Conn) *UserContext {
+ user := c.Locals("user").(*models.User)
+ return getUserData(user)
+}