From 313dc74681bdf9c7cd595ee23c984698f57b1ce9 Mon Sep 17 00:00:00 2001 From: Shamik Karkhanis Date: Fri, 3 Apr 2026 10:35:39 -0400 Subject: [PATCH 1/6] feat(routes): relaxed permissions to view rec event and orgs for demo --- internal/database/mocks/Querier.go | 50 ++++++++++++++--------------- internal/handler/events.go | 1 - internal/handler/organizations.go | 1 - internal/router/router.go | 50 ++++++++++++++--------------- internal/router/router_auth_test.go | 28 ++++++++++++++++ 5 files changed, 78 insertions(+), 52 deletions(-) diff --git a/internal/database/mocks/Querier.go b/internal/database/mocks/Querier.go index fa0b713..f0070ab 100644 --- a/internal/database/mocks/Querier.go +++ b/internal/database/mocks/Querier.go @@ -83,22 +83,22 @@ func (_m *Querier) CreateBotToken(ctx context.Context, arg database.CreateBotTok } // CreateEvent provides a mock function with given fields: ctx, arg -func (_m *Querier) CreateEvent(ctx context.Context, arg database.CreateEventParams) (database.Event, error) { +func (_m *Querier) CreateEvent(ctx context.Context, arg database.CreateEventParams) (database.EventsWithOrgID, error) { ret := _m.Called(ctx, arg) if len(ret) == 0 { panic("no return value specified for CreateEvent") } - var r0 database.Event + var r0 database.EventsWithOrgID var r1 error - if rf, ok := ret.Get(0).(func(context.Context, database.CreateEventParams) (database.Event, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, database.CreateEventParams) (database.EventsWithOrgID, error)); ok { return rf(ctx, arg) } - if rf, ok := ret.Get(0).(func(context.Context, database.CreateEventParams) database.Event); ok { + if rf, ok := ret.Get(0).(func(context.Context, database.CreateEventParams) database.EventsWithOrgID); ok { r0 = rf(ctx, arg) } else { - r0 = ret.Get(0).(database.Event) + r0 = ret.Get(0).(database.EventsWithOrgID) } if rf, ok := ret.Get(1).(func(context.Context, database.CreateEventParams) error); ok { @@ -295,22 +295,22 @@ func (_m *Querier) GetBotTokenByID(ctx context.Context, tokenID uuid.UUID) (data } // GetEventByID provides a mock function with given fields: ctx, eid -func (_m *Querier) GetEventByID(ctx context.Context, eid uuid.UUID) (database.Event, error) { +func (_m *Querier) GetEventByID(ctx context.Context, eid uuid.UUID) (database.EventsWithOrgID, error) { ret := _m.Called(ctx, eid) if len(ret) == 0 { panic("no return value specified for GetEventByID") } - var r0 database.Event + var r0 database.EventsWithOrgID var r1 error - if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID) (database.Event, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID) (database.EventsWithOrgID, error)); ok { return rf(ctx, eid) } - if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID) database.Event); ok { + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID) database.EventsWithOrgID); ok { r0 = rf(ctx, eid) } else { - r0 = ret.Get(0).(database.Event) + r0 = ret.Get(0).(database.EventsWithOrgID) } if rf, ok := ret.Get(1).(func(context.Context, uuid.UUID) error); ok { @@ -697,23 +697,23 @@ func (_m *Querier) ListBotTokens(ctx context.Context) ([]database.ListBotTokensR } // ListEvents provides a mock function with given fields: ctx, arg -func (_m *Querier) ListEvents(ctx context.Context, arg database.ListEventsParams) ([]database.Event, error) { +func (_m *Querier) ListEvents(ctx context.Context, arg database.ListEventsParams) ([]database.EventsWithOrgID, error) { ret := _m.Called(ctx, arg) if len(ret) == 0 { panic("no return value specified for ListEvents") } - var r0 []database.Event + var r0 []database.EventsWithOrgID var r1 error - if rf, ok := ret.Get(0).(func(context.Context, database.ListEventsParams) ([]database.Event, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, database.ListEventsParams) ([]database.EventsWithOrgID, error)); ok { return rf(ctx, arg) } - if rf, ok := ret.Get(0).(func(context.Context, database.ListEventsParams) []database.Event); ok { + if rf, ok := ret.Get(0).(func(context.Context, database.ListEventsParams) []database.EventsWithOrgID); ok { r0 = rf(ctx, arg) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]database.Event) + r0 = ret.Get(0).([]database.EventsWithOrgID) } } @@ -727,23 +727,23 @@ func (_m *Querier) ListEvents(ctx context.Context, arg database.ListEventsParams } // ListEventsByOrg provides a mock function with given fields: ctx, arg -func (_m *Querier) ListEventsByOrg(ctx context.Context, arg database.ListEventsByOrgParams) ([]database.Event, error) { +func (_m *Querier) ListEventsByOrg(ctx context.Context, arg database.ListEventsByOrgParams) ([]database.EventsWithOrgID, error) { ret := _m.Called(ctx, arg) if len(ret) == 0 { panic("no return value specified for ListEventsByOrg") } - var r0 []database.Event + var r0 []database.EventsWithOrgID var r1 error - if rf, ok := ret.Get(0).(func(context.Context, database.ListEventsByOrgParams) ([]database.Event, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, database.ListEventsByOrgParams) ([]database.EventsWithOrgID, error)); ok { return rf(ctx, arg) } - if rf, ok := ret.Get(0).(func(context.Context, database.ListEventsByOrgParams) []database.Event); ok { + if rf, ok := ret.Get(0).(func(context.Context, database.ListEventsByOrgParams) []database.EventsWithOrgID); ok { r0 = rf(ctx, arg) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]database.Event) + r0 = ret.Get(0).([]database.EventsWithOrgID) } } @@ -965,22 +965,22 @@ func (_m *Querier) UpdateBotTokenLastUsed(ctx context.Context, tokenID uuid.UUID } // UpdateEvent provides a mock function with given fields: ctx, arg -func (_m *Querier) UpdateEvent(ctx context.Context, arg database.UpdateEventParams) (database.Event, error) { +func (_m *Querier) UpdateEvent(ctx context.Context, arg database.UpdateEventParams) (database.EventsWithOrgID, error) { ret := _m.Called(ctx, arg) if len(ret) == 0 { panic("no return value specified for UpdateEvent") } - var r0 database.Event + var r0 database.EventsWithOrgID var r1 error - if rf, ok := ret.Get(0).(func(context.Context, database.UpdateEventParams) (database.Event, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, database.UpdateEventParams) (database.EventsWithOrgID, error)); ok { return rf(ctx, arg) } - if rf, ok := ret.Get(0).(func(context.Context, database.UpdateEventParams) database.Event); ok { + if rf, ok := ret.Get(0).(func(context.Context, database.UpdateEventParams) database.EventsWithOrgID); ok { r0 = rf(ctx, arg) } else { - r0 = ret.Get(0).(database.Event) + r0 = ret.Get(0).(database.EventsWithOrgID) } if rf, ok := ret.Get(1).(func(context.Context, database.UpdateEventParams) error); ok { diff --git a/internal/handler/events.go b/internal/handler/events.go index b237796..afc732d 100644 --- a/internal/handler/events.go +++ b/internal/handler/events.go @@ -21,7 +21,6 @@ import ( // @Param limit query int false "Limit (default 20, max 100)" // @Param offset query int false "Offset (default 0)" // @Success 200 {array} dto.EventResponse -// @Security CookieAuth // @Router /events [get] func (h *Handler) ListEvents(w http.ResponseWriter, r *http.Request) { limit, offset := parsePagination(r) diff --git a/internal/handler/organizations.go b/internal/handler/organizations.go index c6c3607..001af24 100644 --- a/internal/handler/organizations.go +++ b/internal/handler/organizations.go @@ -22,7 +22,6 @@ import ( // @Param limit query int false "Limit (default 20, max 100)" // @Param offset query int false "Offset (default 0)" // @Success 200 {array} dto.OrganizationResponse -// @Security CookieAuth // @Router /organizations [get] func (h *Handler) ListOrganizations(w http.ResponseWriter, r *http.Request) { limit, offset := parsePagination(r) diff --git a/internal/router/router.go b/internal/router/router.go index f70a12b..e4522cc 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -22,31 +22,6 @@ func mountProtectedRoutes(r chi.Router, h *handler.Handler, jwtSecret string) { r.Get("/{uid}/events", h.GetUserEvents) }) - r.Route("/organizations", func(r chi.Router) { - r.Get("/", h.ListOrganizations) - r.Post("/", h.CreateOrganization) - r.Get("/{oid}", h.GetOrganization) - r.Put("/{oid}", h.UpdateOrganization) - r.Delete("/{oid}", h.DeleteOrganization) - r.Get("/{oid}/members", h.ListOrgMembers) - r.Post("/{oid}/members", h.AddOrgMember) - r.Delete("/{oid}/members/{uid}", h.RemoveOrgMember) - r.Get("/{oid}/events", h.ListOrgEvents) - r.Get("/{oid}/links", h.ListOrgLinks) - }) - - r.Route("/events", func(r chi.Router) { - r.Get("/", h.ListEvents) - r.Post("/", h.CreateEvent) - r.Get("/org/{oid}", h.ListEventsByOrg) - r.Get("/{eid}", h.GetEvent) - r.Put("/{eid}", h.UpdateEvent) - r.Delete("/{eid}", h.DeleteEvent) - r.Get("/{eid}/registrations", h.ListEventRegistrations) - r.Post("/{eid}/register", h.RegisterForEvent) - r.Delete("/{eid}/register", h.UnregisterFromEvent) - }) - // Links r.Route("/links", func(r chi.Router) { r.Post("/", h.CreateLink) @@ -61,6 +36,25 @@ func mountProtectedRoutes(r chi.Router, h *handler.Handler, jwtSecret string) { r.Delete("/{token_id}", h.RevokeBotToken) }) }) + + r.With(middleware.Auth(jwtSecret)).Post("/organizations", h.CreateOrganization) + r.With(middleware.Auth(jwtSecret)).Get("/organizations/{oid}", h.GetOrganization) + r.With(middleware.Auth(jwtSecret)).Put("/organizations/{oid}", h.UpdateOrganization) + r.With(middleware.Auth(jwtSecret)).Delete("/organizations/{oid}", h.DeleteOrganization) + r.With(middleware.Auth(jwtSecret)).Get("/organizations/{oid}/members", h.ListOrgMembers) + r.With(middleware.Auth(jwtSecret)).Post("/organizations/{oid}/members", h.AddOrgMember) + r.With(middleware.Auth(jwtSecret)).Delete("/organizations/{oid}/members/{uid}", h.RemoveOrgMember) + r.With(middleware.Auth(jwtSecret)).Get("/organizations/{oid}/events", h.ListOrgEvents) + r.With(middleware.Auth(jwtSecret)).Get("/organizations/{oid}/links", h.ListOrgLinks) + + r.With(middleware.Auth(jwtSecret)).Post("/events", h.CreateEvent) + r.With(middleware.Auth(jwtSecret)).Get("/events/org/{oid}", h.ListEventsByOrg) + r.With(middleware.Auth(jwtSecret)).Get("/events/{eid}", h.GetEvent) + r.With(middleware.Auth(jwtSecret)).Put("/events/{eid}", h.UpdateEvent) + r.With(middleware.Auth(jwtSecret)).Delete("/events/{eid}", h.DeleteEvent) + r.With(middleware.Auth(jwtSecret)).Get("/events/{eid}/registrations", h.ListEventRegistrations) + r.With(middleware.Auth(jwtSecret)).Post("/events/{eid}/register", h.RegisterForEvent) + r.With(middleware.Auth(jwtSecret)).Delete("/events/{eid}/register", h.UnregisterFromEvent) } // New creates a new chi router with all routes configured @@ -107,6 +101,10 @@ func New(h *handler.Handler, queries database.Querier, jwtSecret string, allowed }) }) + // Public read-only collection routes + r.Get("/organizations", h.ListOrganizations) + r.Get("/events", h.ListEvents) + mountProtectedRoutes(r, h, jwtSecret) // Bot routes (M2M auth) @@ -156,6 +154,8 @@ func New(h *handler.Handler, queries database.Querier, jwtSecret string, allowed }) r.Route("/v1", func(r chi.Router) { + r.Get("/organizations", h.ListOrganizations) + r.Get("/events", h.ListEvents) mountProtectedRoutes(r, h, jwtSecret) }) diff --git a/internal/router/router_auth_test.go b/internal/router/router_auth_test.go index 43a359d..30ab6eb 100644 --- a/internal/router/router_auth_test.go +++ b/internal/router/router_auth_test.go @@ -265,6 +265,34 @@ func TestBotTokenManagementUsesDatabaseRole(t *testing.T) { assert.Equal(t, http.StatusOK, res.Code) } +func TestPublicCollectionRoutesDoNotRequireAuth(t *testing.T) { + mockQueries := mocks.NewQuerier(t) + routerUnderTest := newTestRouter(mockQueries) + + mockQueries.On("ListOrganizations", mock.Anything, mock.MatchedBy(func(arg database.ListOrganizationsParams) bool { + return arg.Limit == 20 && arg.Offset == 0 + })).Return([]database.Organization{}, nil).Once() + + mockQueries.On("ListEvents", mock.Anything, mock.MatchedBy(func(arg database.ListEventsParams) bool { + return arg.Limit == 20 && arg.Offset == 0 + })).Return([]database.EventsWithOrgID{}, nil).Once() + + orgReq := httptest.NewRequest(http.MethodGet, "/api/v1/organizations", nil) + orgRes := httptest.NewRecorder() + routerUnderTest.ServeHTTP(orgRes, orgReq) + assert.Equal(t, http.StatusOK, orgRes.Code) + + eventReq := httptest.NewRequest(http.MethodGet, "/api/v1/events", nil) + eventRes := httptest.NewRecorder() + routerUnderTest.ServeHTTP(eventRes, eventReq) + assert.Equal(t, http.StatusOK, eventRes.Code) + + protectedReq := httptest.NewRequest(http.MethodPost, "/api/v1/organizations", bytes.NewBufferString(`{"name":"Still Protected"}`)) + protectedRes := httptest.NewRecorder() + routerUnderTest.ServeHTTP(protectedRes, protectedReq) + assert.Equal(t, http.StatusUnauthorized, protectedRes.Code) +} + func newTestRouter(queries database.Querier) http.Handler { cfg := &config.Config{ Env: "test", From 3691eb96fba7b993898e2d1e9ea1ad7404c0e4b1 Mon Sep 17 00:00:00 2001 From: Shamik Karkhanis Date: Fri, 3 Apr 2026 11:33:28 -0400 Subject: [PATCH 2/6] fix(orgs): users can leave without being admin --- internal/handler/organizations.go | 20 +++++-- internal/handler/organizations_test.go | 78 ++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 4 deletions(-) diff --git a/internal/handler/organizations.go b/internal/handler/organizations.go index 001af24..4e7e378 100644 --- a/internal/handler/organizations.go +++ b/internal/handler/organizations.go @@ -315,10 +315,6 @@ func (h *Handler) RemoveOrgMember(w http.ResponseWriter, r *http.Request) { return } - if _, ok := h.requireOrgAdmin(w, r, oid); !ok { - return - } - uidStr := chi.URLParam(r, "uid") uid, err := uuid.Parse(uidStr) if err != nil { @@ -326,6 +322,22 @@ func (h *Handler) RemoveOrgMember(w http.ResponseWriter, r *http.Request) { return } + switch middleware.GetAuthType(r.Context()) { + case "bot": + // Bots retain full access to remove members on behalf of users. + default: + authenticatedUID, _, ok := h.requireAuthenticatedUser(w, r) + if !ok { + return + } + + if uid != authenticatedUID { + if _, ok := h.requireOrgAdmin(w, r, oid); !ok { + return + } + } + } + if err := h.queries.RemoveOrgMember(r.Context(), database.RemoveOrgMemberParams{ Uid: uid, Oid: oid, diff --git a/internal/handler/organizations_test.go b/internal/handler/organizations_test.go index e6db509..140fe24 100644 --- a/internal/handler/organizations_test.go +++ b/internal/handler/organizations_test.go @@ -16,6 +16,7 @@ import ( "github.com/capyrpi/api/internal/middleware" "github.com/go-chi/chi/v5" "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) @@ -157,3 +158,80 @@ func TestAddOrgMemberAllowsSelfJoin(t *testing.T) { assert.Equal(t, http.StatusCreated, rr.Code) } + +func TestRemoveOrgMemberAuthorization(t *testing.T) { + oid := uuid.New() + selfUID := uuid.New() + otherUID := uuid.New() + + tests := []struct { + name string + authUID uuid.UUID + targetUID uuid.UUID + setupMock func(*mocks.Querier) + expectedStatus int + }{ + { + name: "MemberCanRemoveSelf", + authUID: selfUID, + targetUID: selfUID, + setupMock: func(m *mocks.Querier) { + m.On("RemoveOrgMember", mock.Anything, database.RemoveOrgMemberParams{ + Uid: selfUID, + Oid: oid, + }).Return(nil) + }, + expectedStatus: http.StatusNoContent, + }, + { + name: "AdminCanRemoveOtherMember", + authUID: selfUID, + targetUID: otherUID, + setupMock: func(m *mocks.Querier) { + m.On("IsOrgAdmin", mock.Anything, database.IsOrgAdminParams{ + Uid: selfUID, + Oid: oid, + }).Return(pgtype.Bool{Bool: true, Valid: true}, nil) + m.On("RemoveOrgMember", mock.Anything, database.RemoveOrgMemberParams{ + Uid: otherUID, + Oid: oid, + }).Return(nil) + }, + expectedStatus: http.StatusNoContent, + }, + { + name: "NonAdminCannotRemoveOtherMember", + authUID: selfUID, + targetUID: otherUID, + setupMock: func(m *mocks.Querier) { + m.On("IsOrgAdmin", mock.Anything, database.IsOrgAdminParams{ + Uid: selfUID, + Oid: oid, + }).Return(pgtype.Bool{Bool: false, Valid: true}, nil) + }, + expectedStatus: http.StatusForbidden, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockQueries := mocks.NewQuerier(t) + tt.setupMock(mockQueries) + + h := handler.New(mockQueries, &config.Config{}) + + req := httptest.NewRequest(http.MethodDelete, "/organizations/"+oid.String()+"/members/"+tt.targetUID.String(), nil) + req = req.WithContext(context.WithValue(context.Background(), middleware.UserClaimsKey, &middleware.UserClaims{UserID: tt.authUID.String()})) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("oid", oid.String()) + rctx.URLParams.Add("uid", tt.targetUID.String()) + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + rr := httptest.NewRecorder() + http.HandlerFunc(h.RemoveOrgMember).ServeHTTP(rr, req) + + assert.Equal(t, tt.expectedStatus, rr.Code) + }) + } +} From ce8d905f41c2bfa1f58de25e2c505528f9d493a2 Mon Sep 17 00:00:00 2001 From: Shamik Karkhanis Date: Fri, 3 Apr 2026 12:39:42 -0400 Subject: [PATCH 3/6] fixed resource not found issue --- internal/database/querier.go | 2 +- internal/database/queries.sql | 10 +++------- internal/database/queries.sql.go | 15 +++++---------- internal/handler/events.go | 8 +++++++- 4 files changed, 16 insertions(+), 19 deletions(-) diff --git a/internal/database/querier.go b/internal/database/querier.go index ce2f4d6..0ca1acf 100644 --- a/internal/database/querier.go +++ b/internal/database/querier.go @@ -15,7 +15,7 @@ type Querier interface { AddEventHost(ctx context.Context, arg AddEventHostParams) error AddOrgMember(ctx context.Context, arg AddOrgMemberParams) error CreateBotToken(ctx context.Context, arg CreateBotTokenParams) (BotToken, error) - CreateEvent(ctx context.Context, arg CreateEventParams) (EventsWithOrgID, error) + CreateEvent(ctx context.Context, arg CreateEventParams) (Event, error) // Link Queries CreateLink(ctx context.Context, arg CreateLinkParams) (Link, error) CreateOrganization(ctx context.Context, name string) (Organization, error) diff --git a/internal/database/queries.sql b/internal/database/queries.sql index 5159491..c519f1d 100644 --- a/internal/database/queries.sql +++ b/internal/database/queries.sql @@ -87,13 +87,9 @@ ORDER BY e.event_time DESC LIMIT $2 OFFSET $3; -- name: CreateEvent :one -WITH updated AS ( - INSERT INTO events (title, location, event_time, description) - VALUES ($1, $2, $3, $4) - RETURNING * -) -SELECT v.* FROM events_with_org_ids v -WHERE v.eid = (SELECT eid FROM updated); +INSERT INTO events (title, location, event_time, description) +VALUES ($1, $2, $3, $4) +RETURNING *; -- name: UpdateEvent :one WITH updated AS ( diff --git a/internal/database/queries.sql.go b/internal/database/queries.sql.go index 9fbc29d..9d93d34 100644 --- a/internal/database/queries.sql.go +++ b/internal/database/queries.sql.go @@ -80,13 +80,9 @@ func (q *Queries) CreateBotToken(ctx context.Context, arg CreateBotTokenParams) } const createEvent = `-- name: CreateEvent :one -WITH updated AS ( - INSERT INTO events (title, location, event_time, description) - VALUES ($1, $2, $3, $4) - RETURNING eid, location, event_time, description, date_created, date_modified, title -) -SELECT v.eid, v.location, v.event_time, v.description, v.date_created, v.date_modified, v.title, v.org_ids FROM events_with_org_ids v -WHERE v.eid = (SELECT eid FROM updated) +INSERT INTO events (title, location, event_time, description) +VALUES ($1, $2, $3, $4) +RETURNING eid, location, event_time, description, date_created, date_modified, title ` type CreateEventParams struct { @@ -96,14 +92,14 @@ type CreateEventParams struct { Description pgtype.Text `json:"description"` } -func (q *Queries) CreateEvent(ctx context.Context, arg CreateEventParams) (EventsWithOrgID, error) { +func (q *Queries) CreateEvent(ctx context.Context, arg CreateEventParams) (Event, error) { row := q.db.QueryRow(ctx, createEvent, arg.Title, arg.Location, arg.EventTime, arg.Description, ) - var i EventsWithOrgID + var i Event err := row.Scan( &i.Eid, &i.Location, @@ -112,7 +108,6 @@ func (q *Queries) CreateEvent(ctx context.Context, arg CreateEventParams) (Event &i.DateCreated, &i.DateModified, &i.Title, - &i.OrgIds, ) return i, err } diff --git a/internal/handler/events.go b/internal/handler/events.go index afc732d..c7e96a6 100644 --- a/internal/handler/events.go +++ b/internal/handler/events.go @@ -90,7 +90,13 @@ func (h *Handler) CreateEvent(w http.ResponseWriter, r *http.Request) { return } - h.respondJSON(w, http.StatusCreated, toEventResponse(event)) + createdEvent, err := h.queries.GetEventByID(r.Context(), event.Eid) + if err != nil { + h.handleDBError(w, err) + return + } + + h.respondJSON(w, http.StatusCreated, toEventResponse(createdEvent)) } // GetEvent gets an event by ID From 4652849980a713245a2a80fa5d4f005eb78baace Mon Sep 17 00:00:00 2001 From: Shamik Karkhanis Date: Fri, 3 Apr 2026 13:18:50 -0400 Subject: [PATCH 4/6] tests pass --- internal/database/mocks/Querier.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/internal/database/mocks/Querier.go b/internal/database/mocks/Querier.go index f0070ab..f29ca02 100644 --- a/internal/database/mocks/Querier.go +++ b/internal/database/mocks/Querier.go @@ -83,22 +83,22 @@ func (_m *Querier) CreateBotToken(ctx context.Context, arg database.CreateBotTok } // CreateEvent provides a mock function with given fields: ctx, arg -func (_m *Querier) CreateEvent(ctx context.Context, arg database.CreateEventParams) (database.EventsWithOrgID, error) { +func (_m *Querier) CreateEvent(ctx context.Context, arg database.CreateEventParams) (database.Event, error) { ret := _m.Called(ctx, arg) if len(ret) == 0 { panic("no return value specified for CreateEvent") } - var r0 database.EventsWithOrgID + var r0 database.Event var r1 error - if rf, ok := ret.Get(0).(func(context.Context, database.CreateEventParams) (database.EventsWithOrgID, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, database.CreateEventParams) (database.Event, error)); ok { return rf(ctx, arg) } - if rf, ok := ret.Get(0).(func(context.Context, database.CreateEventParams) database.EventsWithOrgID); ok { + if rf, ok := ret.Get(0).(func(context.Context, database.CreateEventParams) database.Event); ok { r0 = rf(ctx, arg) } else { - r0 = ret.Get(0).(database.EventsWithOrgID) + r0 = ret.Get(0).(database.Event) } if rf, ok := ret.Get(1).(func(context.Context, database.CreateEventParams) error); ok { From 821a6211e98af67f6aafbaa86a08c18d54a238ad Mon Sep 17 00:00:00 2001 From: Shamik Karkhanis Date: Fri, 3 Apr 2026 13:31:06 -0400 Subject: [PATCH 5/6] tests pass again, turns out we need schema for benchmarks --- schema.sql | 91 ++++++++++++++++++++++++++++++ tests/benchmarks/benchmark_test.go | 2 +- tests/benchmarks/suite_test.go | 17 +++++- 3 files changed, 106 insertions(+), 4 deletions(-) create mode 100644 schema.sql diff --git a/schema.sql b/schema.sql new file mode 100644 index 0000000..aff59d4 --- /dev/null +++ b/schema.sql @@ -0,0 +1,91 @@ +-- schema.sql +-- Database Schema for CAPY (Club Assistant in Python) + +-- 1. ENUMs & Functions +CREATE TYPE user_role AS ENUM ('student', 'alumni', 'faculty', 'external'); + +CREATE OR REPLACE FUNCTION update_modified_column() +RETURNS TRIGGER AS $$ +BEGIN + NEW.date_modified = CURRENT_DATE; + RETURN NEW; +END; +$$ language 'plpgsql'; + +-- 2. Tables +CREATE TABLE IF NOT EXISTS users ( + uid UUID PRIMARY KEY DEFAULT gen_random_uuid(), + first_name TEXT NOT NULL, + last_name TEXT NOT NULL, + personal_email TEXT UNIQUE, + school_email TEXT UNIQUE, + phone TEXT, + grad_year INT, + role user_role DEFAULT 'student', + date_created DATE DEFAULT CURRENT_DATE, + date_modified DATE DEFAULT CURRENT_DATE +); + +CREATE TABLE IF NOT EXISTS organizations ( + oid UUID PRIMARY KEY DEFAULT gen_random_uuid(), + name TEXT NOT NULL, + date_created DATE DEFAULT CURRENT_DATE, + date_modified DATE DEFAULT CURRENT_DATE +); + +CREATE TABLE IF NOT EXISTS org_members ( + uid UUID REFERENCES users(uid) ON DELETE CASCADE, + oid UUID REFERENCES organizations(oid) ON DELETE CASCADE, + is_admin BOOLEAN DEFAULT FALSE, + date_joined DATE DEFAULT CURRENT_DATE, + last_active DATE DEFAULT CURRENT_DATE, + PRIMARY KEY (uid, oid) +); + +CREATE TABLE IF NOT EXISTS events ( + eid UUID PRIMARY KEY DEFAULT gen_random_uuid(), + location TEXT, + event_time TIMESTAMP, + description TEXT, + date_created DATE DEFAULT CURRENT_DATE, + date_modified DATE DEFAULT CURRENT_DATE +); + +CREATE TABLE IF NOT EXISTS event_hosting ( + eid UUID REFERENCES events(eid) ON DELETE CASCADE, + oid UUID REFERENCES organizations(oid) ON DELETE CASCADE, + PRIMARY KEY (eid, oid) +); + +CREATE TABLE IF NOT EXISTS event_registrations ( + uid UUID REFERENCES users(uid) ON DELETE CASCADE, + eid UUID REFERENCES events(eid) ON DELETE CASCADE, + is_attending BOOLEAN DEFAULT FALSE, + is_admin BOOLEAN DEFAULT FALSE, + date_registered DATE DEFAULT CURRENT_DATE, + PRIMARY KEY (uid, eid) +); + +-- 3. Bot Tokens (global access for M2M authentication) +CREATE TABLE IF NOT EXISTS bot_tokens ( + token_id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + token_hash TEXT NOT NULL, -- bcrypt hash of the token + name TEXT NOT NULL, -- human-readable name for the bot + created_by UUID NOT NULL REFERENCES users(uid), + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + last_used_at TIMESTAMP, + expires_at TIMESTAMP, -- NULL = never expires + is_active BOOLEAN DEFAULT TRUE +); + +CREATE INDEX IF NOT EXISTS idx_bot_tokens_active ON bot_tokens(is_active) WHERE is_active = TRUE; + +-- 4. Triggers +DROP TRIGGER IF EXISTS update_users_modtime ON users; +CREATE TRIGGER update_users_modtime BEFORE UPDATE ON users FOR EACH ROW EXECUTE FUNCTION update_modified_column(); + +DROP TRIGGER IF EXISTS update_orgs_modtime ON organizations; +CREATE TRIGGER update_orgs_modtime BEFORE UPDATE ON organizations FOR EACH ROW EXECUTE FUNCTION update_modified_column(); + +DROP TRIGGER IF EXISTS update_events_modtime ON events; +CREATE TRIGGER update_events_modtime BEFORE UPDATE ON events FOR EACH ROW EXECUTE FUNCTION update_modified_column(); diff --git a/tests/benchmarks/benchmark_test.go b/tests/benchmarks/benchmark_test.go index e38fb5b..525d1e0 100644 --- a/tests/benchmarks/benchmark_test.go +++ b/tests/benchmarks/benchmark_test.go @@ -31,7 +31,7 @@ func BenchmarkHealthEndpoint(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { - resp, err := benchClient.Get(benchServer.URL + "/health") + resp, err := benchClient.Get(benchServer.URL + "/api/health") if err != nil { b.Fatalf("failed to make request: %v", err) } diff --git a/tests/benchmarks/suite_test.go b/tests/benchmarks/suite_test.go index 66f574b..4f581fe 100644 --- a/tests/benchmarks/suite_test.go +++ b/tests/benchmarks/suite_test.go @@ -43,13 +43,12 @@ func TestMain(m *testing.M) { _, filename, _, _ := runtime.Caller(0) projectRoot := filepath.Join(filepath.Dir(filename), "../..") - schemaPath := filepath.Join(projectRoot, "schema.sql") + migrationsPath := filepath.Join(projectRoot, "migrations") - log.Printf("Using schema from: %s", schemaPath) + log.Printf("Using migrations from: %s", migrationsPath) pgContainer, err := postgres.Run(ctx, "postgres:16-alpine", - postgres.WithInitScripts(schemaPath), postgres.WithDatabase("bench_db"), postgres.WithUsername("bench"), postgres.WithPassword("bench"), @@ -82,6 +81,10 @@ func TestMain(m *testing.M) { log.Fatalf("failed to connect to database: %v", err) } + if err := database.RunMigrations(ctx, connStr, migrationsPath); err != nil { + log.Fatalf("failed to apply migrations: %v", err) + } + benchQueries = database.New(benchDB) setupTestData(ctx) @@ -135,6 +138,14 @@ func setupTestData(ctx context.Context) { } benchOrgID = org.Oid.String() + if err := benchQueries.AddOrgMember(ctx, database.AddOrgMemberParams{ + Uid: user.Uid, + Oid: org.Oid, + IsAdmin: pgtype.Bool{Bool: true, Valid: true}, + }); err != nil { + log.Fatalf("failed to add benchmark user as org admin: %v", err) + } + event, err := benchQueries.CreateEvent(ctx, database.CreateEventParams{ Location: pgtype.Text{String: "Bench Event", Valid: true}, EventTime: pgtype.Timestamp{Time: time.Now().Add(24 * time.Hour), Valid: true}, From 19cf65e38815c5e71a0810d8155334e6ffc475e3 Mon Sep 17 00:00:00 2001 From: Shamik Karkhanis Date: Fri, 3 Apr 2026 13:43:25 -0400 Subject: [PATCH 6/6] fixed integration tests --- internal/database/queries.sql | 18 ++++++++++++++++-- internal/database/queries.sql.go | 18 ++++++++++++++++-- internal/router/router.go | 3 +++ internal/testutils/container.go | 9 ++++++--- 4 files changed, 41 insertions(+), 7 deletions(-) diff --git a/internal/database/queries.sql b/internal/database/queries.sql index c519f1d..5aa6836 100644 --- a/internal/database/queries.sql +++ b/internal/database/queries.sql @@ -101,8 +101,22 @@ WITH updated AS ( WHERE eid = $1 RETURNING * ) -SELECT v.* FROM events_with_org_ids v -WHERE v.eid = $1; +SELECT + u.eid, + u.location, + u.event_time, + u.description, + u.date_created, + u.date_modified, + u.title, + COALESCE(hosts.org_ids, ARRAY[]::uuid[]) AS org_ids +FROM updated u +LEFT JOIN ( + SELECT eh.eid, ARRAY_AGG(eh.oid)::uuid[] AS org_ids + FROM event_hosting eh + WHERE eh.eid = $1 + GROUP BY eh.eid +) hosts ON hosts.eid = u.eid; -- name: DeleteEvent :exec DELETE FROM events WHERE eid = $1; diff --git a/internal/database/queries.sql.go b/internal/database/queries.sql.go index 9d93d34..a552607 100644 --- a/internal/database/queries.sql.go +++ b/internal/database/queries.sql.go @@ -968,8 +968,22 @@ WITH updated AS ( WHERE eid = $1 RETURNING eid, location, event_time, description, date_created, date_modified, title ) -SELECT v.eid, v.location, v.event_time, v.description, v.date_created, v.date_modified, v.title, v.org_ids FROM events_with_org_ids v -WHERE v.eid = $1 +SELECT + u.eid, + u.location, + u.event_time, + u.description, + u.date_created, + u.date_modified, + u.title, + COALESCE(hosts.org_ids, ARRAY[]::uuid[]) AS org_ids +FROM updated u +LEFT JOIN ( + SELECT eh.eid, ARRAY_AGG(eh.oid)::uuid[] AS org_ids + FROM event_hosting eh + WHERE eh.eid = $1 + GROUP BY eh.eid +) hosts ON hosts.eid = u.eid ` type UpdateEventParams struct { diff --git a/internal/router/router.go b/internal/router/router.go index e4522cc..969c40f 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -69,6 +69,9 @@ func New(h *handler.Handler, queries database.Querier, jwtSecret string, allowed r.Use(chimiddleware.RequestID) r.Use(middleware.CORS(allowedOrigins, h.Config.Env == "development")) + // Public link resolution alias. + r.Get("/r/{endpoint_url}", h.ResolveLink) + // API routes r.Route("/api", func(r chi.Router) { // Health check (public) diff --git a/internal/testutils/container.go b/internal/testutils/container.go index 8f507df..23af7c1 100644 --- a/internal/testutils/container.go +++ b/internal/testutils/container.go @@ -46,16 +46,15 @@ func SetupTestPostgres(t *testing.T) string { return connStr } -// SetupTestDB creates a fresh Postgres container, initializes schema.sql, and returns the connection pool. +// SetupTestDB creates a fresh Postgres container, applies all migrations, and returns the connection pool. func SetupTestDB(t *testing.T) *pgxpool.Pool { ctx := context.Background() _, filename, _, _ := runtime.Caller(0) projectRoot := filepath.Join(filepath.Dir(filename), "../..") - schemaPath := filepath.Join(projectRoot, "schema.sql") + migrationsPath := filepath.Join(projectRoot, "migrations") pgContainer, err := postgres.Run(ctx, "postgres:16-alpine", - postgres.WithInitScripts(schemaPath), postgres.WithDatabase("test_db"), postgres.WithUsername("test"), postgres.WithPassword("test"), @@ -79,6 +78,10 @@ func SetupTestDB(t *testing.T) *pgxpool.Pool { t.Fatalf("failed to get connection string: %v", err) } + if err := database.RunMigrations(ctx, connStr, migrationsPath); err != nil { + t.Fatalf("failed to apply migrations: %v", err) + } + pool, err := database.NewPool(ctx, connStr) if err != nil { t.Fatalf("failed to connect to database: %v", err)