diff --git a/prov/shm/src/smr_av.c b/prov/shm/src/smr_av.c index de12e152545..61e4344bde5 100644 --- a/prov/shm/src/smr_av.c +++ b/prov/shm/src/smr_av.c @@ -69,9 +69,12 @@ static void smr_map_cleanup(struct smr_map *map) { int64_t i; - for (i = 0; i < SMR_MAX_PEERS; i++) - smr_map_del(map, i); + for (i = 0; i < SMR_MAX_PEERS; i++) { + if (map->peers[i].peer.id < 0) + continue; + smr_map_del(map, i); + } ofi_rbmap_cleanup(&map->rbmap); } @@ -210,7 +213,6 @@ static int smr_av_remove(struct fid_av *av_fid, fi_addr_t *fi_addr, size_t count dlist_foreach(&util_av->ep_list, av_entry) { util_ep = container_of(av_entry, struct util_ep, av_entry); smr_ep = container_of(util_ep, struct smr_ep, util_ep); - smr_unmap_from_endpoint(smr_ep->region, id); if (smr_av->smr_map.num_peers > 0) smr_ep->region->max_sar_buf_per_peer = SMR_MAX_PEERS / diff --git a/prov/shm/src/smr_ep.c b/prov/shm/src/smr_ep.c index 8ad190711fb..dd3d7f53f07 100644 --- a/prov/shm/src/smr_ep.c +++ b/prov/shm/src/smr_ep.c @@ -223,7 +223,9 @@ int64_t smr_verify_peer(struct smr_ep *ep, fi_addr_t fi_addr) return id; if (!ep->region->map->peers[id].region) { + ofi_spin_lock(&ep->region->map->lock); ret = smr_map_to_region(&smr_prov, ep->region->map, id); + ofi_spin_unlock(&ep->region->map->lock); if (ret) return -1; } diff --git a/prov/shm/src/smr_progress.c b/prov/shm/src/smr_progress.c index c5315aa4b1f..5059f576eb2 100644 --- a/prov/shm/src/smr_progress.c +++ b/prov/shm/src/smr_progress.c @@ -878,7 +878,9 @@ static void smr_progress_connreq(struct smr_ep *ep, struct smr_cmd *cmd) peer_smr = smr_peer_region(ep->region, idx); if (!peer_smr) { + ofi_spin_lock(&ep->region->map->lock); ret = smr_map_to_region(&smr_prov, ep->region->map, idx); + ofi_spin_unlock(&ep->region->map->lock); if (ret) { FI_WARN(&smr_prov, FI_LOG_EP_CTRL, "Could not map peer region\n"); @@ -891,14 +893,11 @@ static void smr_progress_connreq(struct smr_ep *ep, struct smr_cmd *cmd) if (peer_smr->pid != (int) cmd->msg.hdr.data) { /* TODO track and update/complete in error any transfers * to or from old mapping - * - * TODO create smr_unmap_region - * this needs to close peer_smr->map->peers[idx].pid_fd - * This case will also return an unmapped region because the idx - * is valid but the region was unmapped */ - munmap(peer_smr, peer_smr->total_size); + ofi_spin_lock(&ep->region->map->lock); + smr_unmap_region(&smr_prov, ep->region->map, idx, false); smr_map_to_region(&smr_prov, ep->region->map, idx); + ofi_spin_unlock(&ep->region->map->lock); peer_smr = smr_peer_region(ep->region, idx); } diff --git a/prov/shm/src/smr_util.c b/prov/shm/src/smr_util.c index 2924ddaa6f2..0c5de80e2a0 100644 --- a/prov/shm/src/smr_util.c +++ b/prov/shm/src/smr_util.c @@ -367,16 +367,15 @@ int smr_map_to_region(const struct fi_provider *prov, struct smr_map *map, } pthread_mutex_unlock(&ep_list_lock); - ofi_spin_lock(&map->lock); if (peer_buf->region) - goto unlock; + return FI_SUCCESS; + assert(ofi_spin_held(&map->lock)); fd = shm_open(name, O_RDWR, S_IRUSR | S_IWUSR); if (fd < 0) { - ret = -errno; FI_WARN_ONCE(prov, FI_LOG_AV, "shm_open error: name %s errno %d\n", name, errno); - goto unlock; + return -errno; } memset(tmp, 0, sizeof(tmp)); @@ -437,8 +436,6 @@ int smr_map_to_region(const struct fi_provider *prov, struct smr_map *map, out: close(fd); -unlock: - ofi_spin_unlock(&map->lock); return ret; } @@ -448,6 +445,7 @@ void smr_map_to_endpoint(struct smr_region *region, int64_t id) struct smr_region *peer_smr; struct smr_peer_data *local_peers; + assert(ofi_spin_held(®ion->map->lock)); peer_smr = smr_peer_region(region, id); if (region->map->peers[id].peer.id < 0 || !peer_smr) return; @@ -479,32 +477,81 @@ void smr_map_to_endpoint(struct smr_region *region, int64_t id) return; } +void smr_unmap_region(const struct fi_provider *prov, struct smr_map *map, + int64_t peer_id, bool local) +{ + struct smr_region *peer_region; + struct smr_peer *peer; + struct util_ep *util_ep; + struct smr_ep *smr_ep; + struct smr_av *av; + int ret = 0; + + assert(ofi_spin_held(&map->lock)); + peer_region = map->peers[peer_id].region; + if (!peer_region) + return; + + peer = &map->peers[peer_id]; + av = container_of(map, struct smr_av, smr_map); + dlist_foreach_container(&av->util_av.ep_list, struct util_ep, util_ep, + av_entry) { + smr_ep = container_of(util_ep, struct smr_ep, util_ep); + smr_unmap_from_endpoint(smr_ep->region, peer_id); + } + + /* Don't unmap memory owned by this pid because the endpoint it belongs + * to might still be active. + */ + if (local) + return; + + if (map->flags & SMR_FLAG_HMEM_ENABLED) { + ret = ofi_hmem_host_unregister(peer_region); + if (ret) + FI_WARN(prov, FI_LOG_EP_CTRL, + "unable to unregister shm with iface\n"); + + if (peer->pid_fd != -1) { + close(peer->pid_fd); + peer->pid_fd = -1; + } + } + + munmap(peer_region, peer_region->total_size); + peer->region = NULL; +} + void smr_unmap_from_endpoint(struct smr_region *region, int64_t id) { struct smr_region *peer_smr; struct smr_peer_data *local_peers, *peer_peers; int64_t peer_id; - local_peers = smr_peer_data(region); if (region->map->peers[id].peer.id < 0) return; peer_smr = smr_peer_region(region, id); - peer_id = smr_peer_data(region)[id].addr.id; - + assert(peer_smr); peer_peers = smr_peer_data(peer_smr); + peer_id = smr_peer_data(region)[id].addr.id; peer_peers[peer_id].addr.id = -1; peer_peers[peer_id].name_sent = 0; + local_peers = smr_peer_data(region); ofi_xpmem_release(&local_peers[peer_id].xpmem); } void smr_exchange_all_peers(struct smr_region *region) { int64_t i; + + ofi_spin_lock(®ion->map->lock); for (i = 0; i < SMR_MAX_PEERS; i++) smr_map_to_endpoint(region, i); + + ofi_spin_unlock(®ion->map->lock); } int smr_map_add(const struct fi_provider *prov, struct smr_map *map, @@ -546,37 +593,24 @@ int smr_map_add(const struct fi_provider *prov, struct smr_map *map, void smr_map_del(struct smr_map *map, int64_t id) { - struct dlist_entry *entry; + struct smr_ep_name *name; + bool local = false; assert(id >= 0 && id < SMR_MAX_PEERS); - pthread_mutex_lock(&ep_list_lock); - entry = dlist_find_first_match(&ep_name_list, smr_match_name, - smr_no_prefix(map->peers[id].peer.name)); + dlist_foreach_container(&ep_name_list, struct smr_ep_name, name, entry) { + if (strcmp(name->name, map->peers[id].peer.name)) { + local = true; + break; + } + } pthread_mutex_unlock(&ep_list_lock); - ofi_spin_lock(&map->lock); - (void) ofi_rbmap_find_delete(&map->rbmap, - (void *) map->peers[id].peer.name); - + smr_unmap_region(&smr_prov, map, id, local); map->peers[id].fiaddr = FI_ADDR_NOTAVAIL; map->peers[id].peer.id = -1; map->num_peers--; - - if (!map->peers[id].region) - goto unlock; - - if (!entry) { - if (map->flags & SMR_FLAG_HMEM_ENABLED) { - if (map->peers[id].pid_fd != -1) - close(map->peers[id].pid_fd); - - (void) ofi_hmem_host_unregister(map->peers[id].region); - } - munmap(map->peers[id].region, map->peers[id].region->total_size); - map->peers[id].region = NULL; - } -unlock: + ofi_rbmap_find_delete(&map->rbmap, map->peers[id].peer.name); ofi_spin_unlock(&map->lock); } diff --git a/prov/shm/src/smr_util.h b/prov/shm/src/smr_util.h index c5bf8124873..7ed4e1e426f 100644 --- a/prov/shm/src/smr_util.h +++ b/prov/shm/src/smr_util.h @@ -356,6 +356,8 @@ void smr_cleanup(void); int smr_map_to_region(const struct fi_provider *prov, struct smr_map *map, int64_t id); void smr_map_to_endpoint(struct smr_region *region, int64_t id); +void smr_unmap_region(const struct fi_provider *prov, struct smr_map *map, + int64_t id, bool found); void smr_unmap_from_endpoint(struct smr_region *region, int64_t id); void smr_exchange_all_peers(struct smr_region *region); int smr_map_add(const struct fi_provider *prov, struct smr_map *map,