Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,6 @@ profile.cov
# Go workspace file
go.work
go.work.sum

# JetBrains Intellij IDE
.idea/
4 changes: 2 additions & 2 deletions cmd/launcher/web/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func corsWithArgs(frontendAddress string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", frontendAddress)
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
Expand All @@ -77,7 +77,7 @@ func (a *apiLauncher) SetupSubrouters(router *mux.Router, config *launcher.Confi
corsHandler := corsWithArgs(a.config.frontendAddress)(apiHandler)

// Register it at the /api/ path
router.Methods("GET", "POST", "DELETE", "OPTIONS").PathPrefix("/api/").Handler(
router.Methods("GET", "POST", "DELETE", "PATCH", "OPTIONS").PathPrefix("/api/").Handler(
http.StripPrefix("/api", corsHandler),
)

Expand Down
46 changes: 46 additions & 0 deletions server/adkrest/controllers/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,49 @@ func (c *SessionsAPIController) ListSessionsHandler(rw http.ResponseWriter, req
}
EncodeJSONResponse(sessions, http.StatusOK, rw)
}

// UpdateSessionHandler handles updating a session's state, specifically it performs a PATCH.
func (c *SessionsAPIController) UpdateSessionHandler(rw http.ResponseWriter, req *http.Request) {
params := mux.Vars(req)
sessionID, err := models.SessionIDFromHTTPParameters(params)
if err != nil {
http.Error(rw, err.Error(), http.StatusBadRequest)
return
}
if sessionID.ID == "" {
http.Error(rw, "session_id parameter is required", http.StatusBadRequest)
return
}

patchRequest := models.PatchSessionStateDeltaRequest{}
if err := json.NewDecoder(req.Body).Decode(&patchRequest); err != nil {
http.Error(rw, err.Error(), http.StatusBadRequest)
return
}

// Normalize directives (e.g. delete) to nil values for the service layer
normalizedDelta, err := models.NormalizeStateDelta(patchRequest.StateDelta)
if err != nil {
http.Error(rw, err.Error(), http.StatusBadRequest)
return
}

// Use PatchState to update state without appending an event
patchResp, err := c.service.PatchState(req.Context(), &session.PatchStateRequest{
AppName: sessionID.AppName,
UserID: sessionID.UserID,
SessionID: sessionID.ID,
StateDelta: normalizedDelta,
})
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}

respSession, err := models.FromSession(patchResp.Session)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
EncodeJSONResponse(respSession, http.StatusOK, rw)
}
32 changes: 32 additions & 0 deletions server/adkrest/internal/fakes/testsessionservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,4 +185,36 @@ func (s *FakeSessionService) AppendEvent(ctx context.Context, curSession session
return nil
}

func (s *FakeSessionService) PatchState(ctx context.Context, req *session.PatchStateRequest) (*session.PatchStateResponse, error) {
key := SessionKey{
AppName: req.AppName,
UserID: req.UserID,
SessionID: req.SessionID,
}
testSession, ok := s.Sessions[key]
if !ok {
return nil, fmt.Errorf("session %q not found", req.SessionID)
}

// Apply state delta
if testSession.SessionState == nil {
testSession.SessionState = make(TestState)
}
for k, v := range req.StateDelta {
if v == nil {
delete(testSession.SessionState, k)
} else {
testSession.SessionState[k] = v
}
}

// Update timestamp
testSession.UpdatedAt = time.Now()
s.Sessions[key] = testSession

return &session.PatchStateResponse{
Session: &testSession,
}, nil
}

var _ session.Service = (*FakeSessionService)(nil)
64 changes: 64 additions & 0 deletions server/adkrest/internal/models/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@ import (
"google.golang.org/adk/session"
)

// State delta directive constants
const (
// stateUpdateKey is the special key used in state delta directives
// to indicate a patch operation (e.g., delete).
stateUpdateKey = "$adk_state_update"

// stateUpdateDelete is the directive value indicating a key should be deleted.
stateUpdateDelete = "delete"
)

// Session represents an agent's session.
type Session struct {
ID string `json:"id"`
Expand All @@ -38,6 +48,10 @@ type CreateSessionRequest struct {
Events []Event `json:"events"`
}

type PatchSessionStateDeltaRequest struct {
StateDelta map[string]any `json:"stateDelta"`
}

type SessionID struct {
ID string `mapstructure:"session_id,optional"`
AppName string `mapstructure:"app_name,required"`
Expand Down Expand Up @@ -105,3 +119,53 @@ func (s Session) Validate() error {
}
return nil
}

// NormalizeStateDelta processes state delta directives and converts them
// into a normalized representation suitable for the service layer.
// Delete directives ({"$adk_state_update": "delete"}) are converted to nil values.
// Returns a new map with normalized values.
func NormalizeStateDelta(stateDelta map[string]any) (map[string]any, error) {
normalized := make(map[string]any, len(stateDelta))
for key, value := range stateDelta {
// Check if value is a directive (map with special key)
directive, isDirective := value.(map[string]any)
if isDirective {
// Check if this map contains a state update directive
updateValue, hasDirective := directive[stateUpdateKey]
if hasDirective {
normalizedValue, err := processDirective(key, updateValue)
if err != nil {
return nil, err
}
normalized[key] = normalizedValue
continue
}
// else: it's a normal map value, fall through and set it as-is
}

// Normal value (including normal maps): keep it directly.
normalized[key] = value
}

return normalized, nil
}

// processDirective handles a state update directive and returns the normalized value.
func processDirective(key string, updateValue any) (any, error) {
updateStr, ok := updateValue.(string)
if !ok {
return nil, fmt.Errorf(
"invalid directive value type for key %q: expected string, got %T",
key,
updateValue,
)
}

switch updateStr {
case stateUpdateDelete:
// Delete directive: return nil to indicate deletion
return nil, nil
default:
return nil, fmt.Errorf("unknown state update directive %q for key %q", updateStr, key)
}
}
6 changes: 6 additions & 0 deletions server/adkrest/internal/routers/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,11 @@ func (r *SessionsAPIRouter) Routes() Routes {
Pattern: "/apps/{app_name}/users/{user_id}/sessions",
HandlerFunc: r.sessionController.ListSessionsHandler,
},
Route{
Name: "UpdateSession",
Methods: []string{http.MethodPatch},
Pattern: "/apps/{app_name}/users/{user_id}/sessions/{session_id}",
HandlerFunc: r.sessionController.UpdateSessionHandler,
},
}
}
121 changes: 121 additions & 0 deletions session/database/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,127 @@ func (s *databaseService) AppendEvent(ctx context.Context, curSession session.Se
return sess.appendEvent(event)
}

// PatchState updates a session's state without appending an event
func (s *databaseService) PatchState(ctx context.Context, req *session.PatchStateRequest) (*session.PatchStateResponse, error) {
appName, userID, sessionID := req.AppName, req.UserID, req.SessionID
if appName == "" || userID == "" || sessionID == "" {
return nil, fmt.Errorf("app_name, user_id, session_id are required, got app_name: %q, user_id: %q, session_id: %q", appName, userID, sessionID)
}

var responseSession *localSession

err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// Fetch the session from storage
var storageSess storageSession
err := tx.Where(&storageSession{AppName: appName, UserID: userID, ID: sessionID}).
First(&storageSess).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("session %q not found", sessionID)
}
return fmt.Errorf("failed to get session: %w", err)
}

// Fetch App and User states
storageApp, err := fetchStorageAppState(tx, appName)
if err != nil {
return err
}
storageUser, err := fetchStorageUserState(tx, appName, userID)
if err != nil {
return err
}

// Extract state deltas
appDelta, userDelta, sessionDelta := extractStateDeltas(req.StateDelta)

// Apply app state delta
if len(appDelta) > 0 {
for key, value := range appDelta {
if value == nil {
delete(storageApp.State, key)
} else {
storageApp.State[key] = value
}
}
if err := tx.Save(&storageApp).Error; err != nil {
return fmt.Errorf("failed to save app state: %w", err)
}
}

// Apply user state delta
if len(userDelta) > 0 {
for key, value := range userDelta {
if value == nil {
delete(storageUser.State, key)
} else {
storageUser.State[key] = value
}
}
if err := tx.Save(&storageUser).Error; err != nil {
return fmt.Errorf("failed to save user state: %w", err)
}
}

// Apply session state delta
if len(sessionDelta) > 0 {
for key, value := range sessionDelta {
if value == nil {
delete(storageSess.State, key)
} else {
storageSess.State[key] = value
}
}
}

// Update timestamp
storageSess.UpdateTime = time.Now()

// Save the session
if err := tx.Save(&storageSess).Error; err != nil {
return fmt.Errorf("failed to save session state: %w", err)
}

// Create response session
responseSession, err = createSessionFromStorageSession(&storageSess)
if err != nil {
return fmt.Errorf("failed to map storage object: %w", err)
}
responseSession.state = mergeStates(storageApp.State, storageUser.State, responseSession.state)

// Fetch events for the response
var storageEvents []storageEvent
if err := tx.Model(&storageEvent{}).
Where("app_name = ?", appName).
Where("user_id = ?", userID).
Where("session_id = ?", sessionID).
Order("timestamp ASC").
Find(&storageEvents).Error; err != nil {
return fmt.Errorf("database error while fetching events: %w", err)
}

responseEvents := make([]*session.Event, 0, len(storageEvents))
for i := range storageEvents {
evt, err := createEventFromStorageEvent(&storageEvents[i])
if err != nil {
return fmt.Errorf("failed to map storage event: %w", err)
}
responseEvents = append(responseEvents, evt)
}
responseSession.events = responseEvents

return nil
})

if err != nil {
return nil, err
}

return &session.PatchStateResponse{
Session: responseSession,
}, nil
}

// applyEvent fetches the session, validates it, applies state changes from an
// event, and saves the event atomically.
func (s *databaseService) applyEvent(ctx context.Context, session *localSession, event *session.Event) error {
Expand Down
50 changes: 50 additions & 0 deletions session/inmemory.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,56 @@ func (s *inMemoryService) AppendEvent(ctx context.Context, curSession Session, e
return nil
}

// PatchState updates a session's state without appending an event.
func (s *inMemoryService) PatchState(ctx context.Context, req *PatchStateRequest) (*PatchStateResponse, error) {
appName, userID, sessionID := req.AppName, req.UserID, req.SessionID
if appName == "" || userID == "" || sessionID == "" {
return nil, fmt.Errorf("app_name, user_id, session_id are required, got app_name: %q, user_id: %q, session_id: %q", appName, userID, sessionID)
}

s.mu.Lock()
defer s.mu.Unlock()

id := id{
appName: appName,
userID: userID,
sessionID: sessionID,
}

storedSession, ok := s.sessions.Get(id.Encode())
if !ok {
return nil, fmt.Errorf("session %q not found", sessionID)
}

// Apply state delta: extract app/user/session deltas
appDelta, userDelta, sessionDelta := sessionutils.ExtractStateDeltas(req.StateDelta)

// Update app and user state
s.updateAppState(appDelta, appName)
s.updateUserState(userDelta, appName, userID)

// Apply session state delta: add/overwrite keys, delete keys with nil value
for key, value := range sessionDelta {
if value == nil {
delete(storedSession.state, key)
} else {
storedSession.state[key] = value
}
}

// Update timestamp
storedSession.updatedAt = time.Now()

// Return a copy of the updated session
copiedSession := copySessionWithoutStateAndEvents(storedSession)
copiedSession.state = s.mergeStates(storedSession.state, appName, userID)
copiedSession.events = slices.Clone(storedSession.events)

return &PatchStateResponse{
Session: copiedSession,
}, nil
}

func (s *inMemoryService) updateAppState(appDelta stateMap, appName string) stateMap {
innerMap, ok := s.appState[appName]
if !ok {
Expand Down
Loading